• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import time
16import threading
17import unittest
18import platform
19
20from grpc._cython import cygrpc
21from tests.unit._cython import test_utilities
22from tests.unit import test_common
23from tests.unit import resources
24
# Host override handed to setUpMixin by the SSL-secured test case, where it is
# used as the SSL target-name-override channel argument and as the host the
# server is expected to observe.
_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
# Metadata key/value pair that _metadata_plugin reports through its callback.
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
# Flags value passed to every batch operation constructed in this module.
_EMPTY_FLAGS = 0
29
30
def _metadata_plugin(context, callback):
    """Report the fixed test call-credentials metadata through `callback`.

    `callback` is invoked exactly once with a one-entry metadata tuple, the
    `cygrpc.StatusCode.ok` status code and an empty byte string. `context` is
    accepted but not used.
    """
    call_credentials_metadata = ((
        _CALL_CREDENTIALS_METADATA_KEY,
        _CALL_CREDENTIALS_METADATA_VALUE,
    ),)
    callback(call_credentials_metadata, cygrpc.StatusCode.ok, b'')
36
37
class TypeSmokeTest(unittest.TestCase):
    """Smoke tests that create, lightly exercise and release cygrpc objects."""

    def testCompletionQueueUpDown(self):
        """Create a completion queue and drop the reference straight away."""
        queue = cygrpc.CompletionQueue()
        del queue

    def testServerUpDown(self):
        """Create a server from a set of arguments and drop the reference."""
        reuseport_argument = (b'grpc.so_reuseport', 0)
        server = cygrpc.Server({reuseport_argument}, False)
        del server

    def testChannelUpDown(self):
        """Create a channel with no arguments or credentials and close it."""
        target = b'[::]:0'
        channel = cygrpc.Channel(target, None, None)
        channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')

    def test_metadata_plugin_call_credentials_up_down(self):
        """Construct plugin call credentials; the result is discarded."""
        plugin_name = b'test plugin name!'
        cygrpc.MetadataPluginCallCredentials(_metadata_plugin, plugin_name)

    def testServerStartNoExplicitShutdown(self):
        """Start a server and release it without calling shutdown."""
        reuseport_argument = (b'grpc.so_reuseport', 0)
        server = cygrpc.Server([reuseport_argument], False)
        queue = cygrpc.CompletionQueue()
        server.register_completion_queue(queue)
        bound_port = server.add_http2_port(b'[::]:0')
        self.assertIsInstance(bound_port, int)
        server.start()
        del server

    def testServerStartShutdown(self):
        """Start a server, shut it down and check the shutdown event."""
        queue = cygrpc.CompletionQueue()
        reuseport_argument = (b'grpc.so_reuseport', 0)
        server = cygrpc.Server([reuseport_argument], False)
        server.add_http2_port(b'[::]:0')
        server.register_completion_queue(queue)
        server.start()

        # The tag given to shutdown() must come back on the completion event.
        expected_tag = object()
        server.shutdown(queue, expected_tag)
        shutdown_event = queue.poll()
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         shutdown_event.completion_type)
        self.assertIs(expected_tag, shutdown_event.tag)

        del server
        del queue
90
91
class ServerClientMixin(object):
    """Shared fixture and test cases for a cygrpc server/client pair.

    Concrete test classes in this module mix this in next to
    `unittest.TestCase` and call `setUpMixin` / `tearDownMixin` from their own
    `setUp` / `tearDown`.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on `[::]:0` and open a client channel to it.

        Args:
          server_credentials: Passed to `add_http2_port` when truthy;
            otherwise the port is added without credentials.
          client_credentials: When truthy, the channel is created with these
            credentials and an SSL target-name-override argument set to
            `host_override`; otherwise the channel gets an empty argument set
            and no credentials.
          host_override: When truthy, calls use the default host (`None`) and
            this value is the host the server is expected to see; otherwise
            an arbitrary explicit host is sent and expected.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([(
            b'grpc.so_reuseport',
            0,
        )], False)
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(b'[::]:0',
                                                   server_credentials)
        else:
            self.port = self.server.add_http2_port(b'[::]:0')
        self.server.start()

        self.client_completion_queue = cygrpc.CompletionQueue()
        # Both branches dial the port reported by add_http2_port.
        target = 'localhost:{}'.format(self.port).encode()
        if client_credentials:
            client_channel_arguments = ((
                cygrpc.ChannelArgKey.ssl_target_name_override,
                host_override,
            ),)
            self.client_channel = cygrpc.Channel(
                target, client_channel_arguments, client_credentials)
        else:
            self.client_channel = cygrpc.Channel(target, set(), None)

        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b'hostess'
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and drop every fixture attribute."""
        self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _perform_queue_operations(self, operations, call, queue, deadline,
                                  description):
        """Perform the operations with given call, queue, and deadline.

        The work is wrapped in a `test_utilities.SimpleFuture`, which is
        returned (presumably run asynchronously by that helper). The wrapped
        function starts the batch, polls `queue` until `deadline`, and checks
        that the resulting event completed successfully with the batch's tag.

        Any failure is re-raised as an `Exception` whose message contains
        `description` followed by the text of the underlying error.

        NOTE(review): the batch is started with `start_client_batch` even
        though `test_6522` passes a server-side call here - confirm that is
        intended.

        Returns:
          The `SimpleFuture`; the wrapped function returns the polled event.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(cygrpc.CompletionType.operation_complete,
                                 event.completion_type)
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # Format the exception itself: Python 3 exceptions have no
                # `.message` attribute, so reading it would replace the real
                # failure with an AttributeError.
                raise Exception("Error in '{}': {}".format(description, error))
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        """Run one request/response exchange and verify both batch results."""
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = 'key'
        CLIENT_METADATA_ASCII_VALUE = 'val'
        CLIENT_METADATA_BIN_KEY = 'key-bin'
        CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
        SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
        SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
        SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
        SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = 'our work is never over'
        REQUEST = b'in death a member of project mayhem has a name'
        RESPONSE = b'his name is robert paulson'
        METHOD = b'twinkies'

        # Server side: ask for the next incoming call.
        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)
        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        # Client side: issue the whole call as a single six-operation batch.
        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        client_operations = [
            cygrpc.SendInitialMetadataOperation(client_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        ]
        client_call = self.client_channel.integrated_call(
            0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
            None, [(client_operations, client_call_tag)])
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event)

        # The server must observe the call with the client's details.
        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         request_event.completion_type)
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(client_initial_metadata,
                                             request_event.invocation_metadata))
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(abs(DEADLINE - request_event.call_details.deadline),
                        DEADLINE_TOLERANCE)

        # Server side: answer with a single five-operation batch.
        server_call_tag = object()
        server_call = request_event.call
        server_initial_metadata = ((
            SERVER_INITIAL_METADATA_KEY,
            SERVER_INITIAL_METADATA_VALUE,
        ),)
        server_trailing_metadata = ((
            SERVER_TRAILING_METADATA_KEY,
            SERVER_TRAILING_METADATA_VALUE,
        ),)
        server_start_batch_result = server_call.start_server_batch([
            cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                server_trailing_metadata, SERVER_STATUS_CODE,
                SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
        ], server_call_tag)
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        # Client batch: every operation type appears exactly once and the
        # receive operations carry what the server sent.
        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            op_type = client_result.type()
            # we expect each op type to be unique
            self.assertNotIn(op_type, found_client_op_types)
            found_client_op_types.add(op_type)
            if op_type == cygrpc.OperationType.receive_initial_metadata:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata()))
            elif op_type == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif op_type == cygrpc.OperationType.receive_status_on_client:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata()))
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            {
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.send_close_from_client,
                cygrpc.OperationType.receive_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.receive_status_on_client,
            }, found_client_op_types)

        # Server batch: same uniqueness check, plus the received request and
        # a non-cancelled close.
        self.assertEqual(5, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            op_type = server_result.type()
            self.assertNotIn(op_type, found_server_op_types)
            found_server_op_types.add(op_type)
            if op_type == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif op_type == cygrpc.OperationType.receive_close_on_server:
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            {
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.receive_close_on_server,
                cygrpc.OperationType.send_status_from_server,
            }, found_server_op_types)

        del client_call
        del server_call

    def test_6522(self):
        """Exchange ten empty messages over one call, then close it cleanly.

        NOTE(review): the name presumably refers to a tracked issue number -
        confirm against the project history.
        """
        DEADLINE = time.time() + 5
        METHOD = b'twinkies'

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)
        # Same sanity check test_echo performs on its request_call.
        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call = self.client_channel.segregated_call(
            0, METHOD, self.host_argument, DEADLINE, None, None, [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            empty_metadata, _EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                    ],
                    object(),
                ),
                (
                    [
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    object(),
                ),
            ])

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_queue_operations(operations, server_call,
                                                  self.server_completion_queue,
                                                  DEADLINE, description)

        server_event_future = perform_server_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
        ], "Server prologue")

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            client_call.operate([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Client message")
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event)
            server_event_future = perform_server_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Server receive")

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate([
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
        ], "Client epilogue")
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        client_events_future = test_utilities.SimpleFuture(lambda: {
            client_call.next_event(),
            client_call.next_event(),
        })

        server_event_future = perform_server_operations([
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
        ], "Server epilogue")

        client_events_future.result()  # force completion
        server_event_future.result()
387
388
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the mixin's tests with no server or client credentials."""

    def setUp(self):
        """Build the fixture with every credential and override unset."""
        self.setUpMixin(server_credentials=None,
                        client_credentials=None,
                        host_override=None)

    def tearDown(self):
        """Release the fixture built in setUp."""
        self.tearDownMixin()
396
397
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the mixin's tests with SSL server and channel credentials."""

    def setUp(self):
        """Build SSL credentials from the test resources and set up with them."""
        key_cert_pair = cygrpc.SslPemKeyCertPair(resources.private_key(),
                                                 resources.certificate_chain())
        ssl_server_credentials = cygrpc.server_credentials_ssl(
            None, [key_cert_pair], False)
        ssl_channel_credentials = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None)
        self.setUpMixin(ssl_server_credentials, ssl_channel_credentials,
                        _SSL_HOST_OVERRIDE)

    def tearDown(self):
        """Release the fixture built in setUp."""
        self.tearDownMixin()
413
414
# When executed directly, run this module's test cases with verbose output.
if __name__ == '__main__':
    unittest.main(verbosity=2)
417