• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server certificate rotation.
15
Here we test various aspects of gRPC Python's support (and, by
extension, in some cases gRPC Core's support) for server certificate
rotation.
18
19* ServerSSLCertReloadTestWithClientAuth: test ability to rotate
20  server's SSL cert for use in future channels with clients while not
21  affecting any existing channel. The server requires client
22  authentication.
23
24* ServerSSLCertReloadTestWithoutClientAuth: like
25  ServerSSLCertReloadTestWithClientAuth except that the server does
26  not authenticate the client.
27
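* ServerSSLCertReloadTestCertConfigReuse: tests gRPC Python's ability
  to handle a user's reuse of ServerCertificateConfiguration instances.

* ServerSSLCertConfigFetcherParamsChecks: checks that
  grpc.dynamic_ssl_server_credentials raises TypeError for an invalid
  initial certificate configuration or certificate configuration fetcher.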
30"""
31
32import abc
33import collections
34import os
35import six
36import threading
37import unittest
38import logging
39
40from concurrent import futures
41
42import grpc
43from tests.unit import resources
44from tests.unit import test_common
45from tests.testing import _application_common
46from tests.testing import _server_application
47from tests.testing.proto import services_pb2_grpc
48
# Root CA certificates of the two independent test certificate hierarchies.
CA_1_PEM = resources.cert_hier_1_root_ca_cert()
CA_2_PEM = resources.cert_hier_2_root_ca_cert()

# Client credentials. Each certificate chain is the leaf certificate
# followed by the intermediate CA certificate of the same hierarchy.
CLIENT_KEY_1_PEM = resources.cert_hier_1_client_1_key()
CLIENT_CERT_CHAIN_1_PEM = (
    resources.cert_hier_1_client_1_cert()
    + resources.cert_hier_1_intermediate_ca_cert())

CLIENT_KEY_2_PEM = resources.cert_hier_2_client_1_key()
CLIENT_CERT_CHAIN_2_PEM = (
    resources.cert_hier_2_client_1_cert()
    + resources.cert_hier_2_intermediate_ca_cert())

# Server credentials, laid out the same way as the client ones.
SERVER_KEY_1_PEM = resources.cert_hier_1_server_1_key()
SERVER_CERT_CHAIN_1_PEM = (
    resources.cert_hier_1_server_1_cert()
    + resources.cert_hier_1_intermediate_ca_cert())

SERVER_KEY_2_PEM = resources.cert_hier_2_server_1_key()
SERVER_CERT_CHAIN_2_PEM = (
    resources.cert_hier_2_server_1_cert()
    + resources.cert_hier_2_intermediate_ca_cert())

# One invocation recorded by CertConfigFetcher (a small hand-rolled mock):
# whether the invocation raised, and the certificate configuration it
# returned (None when it raised or was configured to return nothing).
Call = collections.namedtuple('Call', 'did_raise returned_cert_config')
71
72
def _create_channel(port, credentials):
    """Opens a secure channel to the test server at localhost:``port``."""
    address = 'localhost:{}'.format(port)
    return grpc.secure_channel(address, credentials)
75
76
def _create_client_stub(channel, expect_success):
    """Returns a FirstService stub bound to ``channel``.

    When the caller expects its RPC to succeed, this first blocks (for up to
    10 seconds) until the channel is ready. Per Nathaniel, there is a
    robustness issue when a channel is used before it is actually ready.
    When failure is expected, no wait is performed.
    """
    if expect_success:
        readiness = grpc.channel_ready_future(channel)
        readiness.result(timeout=10)
    stub = services_pb2_grpc.FirstServiceStub(channel)
    return stub
83
84
class CertConfigFetcher(object):
    """Scriptable certificate configuration fetcher that records its calls.

    An instance is handed to ``grpc.dynamic_ssl_server_credentials`` as the
    certificate configuration fetcher. Tests script its next responses with
    ``configure`` and afterwards inspect ``getCalls`` to learn how many times
    the server invoked it and with what outcome. All state is guarded by a
    lock because invocations come from the server side rather than directly
    from the test code.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._calls = []  # Call records, in invocation order.
        self._should_raise = False
        self._cert_config = None

    def reset(self):
        """Forgets recorded calls and restores the default behavior.

        By default, invocations do not raise and return None.
        """
        with self._lock:
            self._calls = []
            self._should_raise = False
            self._cert_config = None

    def configure(self, should_raise, cert_config):
        """Sets how subsequent invocations behave.

        Args:
          should_raise: if true, invocations raise ValueError.
          cert_config: value returned by invocations that do not raise.
            None tells the server to keep its current certificate
            configuration. Must not be combined with ``should_raise``.
        """
        assert not (should_raise and cert_config), (
            "should not specify both should_raise and a cert_config at the same time"
        )
        with self._lock:
            self._should_raise = should_raise
            self._cert_config = cert_config

    def getCalls(self):
        """Returns a snapshot of the calls recorded since the last reset.

        A copy is returned, taken under the lock. Callers can then index and
        iterate it safely while further invocations append to the internal
        list, and mutating the result cannot corrupt the fetcher's record.
        """
        with self._lock:
            return list(self._calls)

    def __call__(self):
        """Records this invocation, then raises or returns the scripted config.

        Raises:
          ValueError: if configured with ``should_raise``.
        """
        with self._lock:
            if self._should_raise:
                self._calls.append(Call(True, None))
                raise ValueError('just for fun, should not affect the test')
            else:
                self._calls.append(Call(False, self._cert_config))
                return self._cert_config
119
120
class _ServerSSLCertReloadTest(
        six.with_metaclass(abc.ABCMeta, unittest.TestCase)):
    """Shared fixture and scenario for server certificate rotation tests.

    ``setUp`` starts a server whose initial certificate configuration
    presents server key/chain 1 and trusts client certificates issued under
    CA 2. Concrete subclasses choose, via ``require_client_auth``, whether
    the server demands client certificates, and expose ``_test`` as a test
    method.
    """

    def __init__(self, *args, **kwargs):
        super(_ServerSSLCertReloadTest, self).__init__(*args, **kwargs)
        self.server = None
        self.port = None

    @abc.abstractmethod
    def require_client_auth(self):
        """Returns whether the server requires client authentication."""
        raise NotImplementedError()

    def setUp(self):
        self.server = test_common.test_server()
        services_pb2_grpc.add_FirstServiceServicer_to_server(
            _server_application.FirstServiceServicer(), self.server)
        # Present server cert 1 and trust clients whose certs chain to CA 2.
        initial_cert_config = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)],
            root_certificates=CA_2_PEM)
        self.cert_config_fetcher = CertConfigFetcher()
        server_credentials = grpc.dynamic_ssl_server_credentials(
            initial_cert_config,
            self.cert_config_fetcher,
            require_client_authentication=self.require_client_auth())
        # Port 0 lets the system choose a free port; the chosen one is
        # returned and used by the clients below.
        self.port = self.server.add_secure_port('[::]:0', server_credentials)
        self.server.start()

    def tearDown(self):
        if self.server:
            self.server.stop(None)

    def _perform_rpc(self, client_stub, expect_success):
        """Issues one unary-unary RPC and checks that it succeeds or fails."""
        # we don't care about the actual response of the rpc; only
        # whether we can perform it or not, and if not, the status
        # code must be UNAVAILABLE (or UNKNOWN, see below)
        request = _application_common.UNARY_UNARY_REQUEST
        if expect_success:
            response = client_stub.UnUn(request)
            self.assertEqual(response, _application_common.UNARY_UNARY_RESPONSE)
        else:
            with self.assertRaises(grpc.RpcError) as exception_context:
                client_stub.UnUn(request)
            # If TLS 1.2 is used, then the client receives an alert message
            # before the handshake is complete, so the status is UNAVAILABLE. If
            # TLS 1.3 is used, then the client receives the alert message after
            # the handshake is complete, so the TSI handshaker returns the
            # TSI_PROTOCOL_FAILURE result. This result does not have a
            # corresponding status code, so this yields an UNKNOWN status.
            self.assertIn(exception_context.exception.code(),
                          (grpc.StatusCode.UNAVAILABLE,
                           grpc.StatusCode.UNKNOWN))

    def _do_one_shot_client_rpc(self,
                                expect_success,
                                root_certificates=None,
                                private_key=None,
                                certificate_chain=None):
        """Performs one RPC on a fresh channel that is closed afterwards."""
        credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates,
            private_key=private_key,
            certificate_chain=certificate_chain)
        with _create_channel(self.port, credentials) as client_channel:
            client_stub = _create_client_stub(client_channel, expect_success)
            self._perform_rpc(client_stub, expect_success)

    def _reconfigure_fetcher(self, should_raise=False, cert_config=None):
        """Clears the fetcher's recorded calls and scripts its next responses."""
        self.cert_config_fetcher.reset()
        self.cert_config_fetcher.configure(should_raise, cert_config)

    def _assert_single_fetch(self, did_raise=False):
        """Asserts the fetcher ran exactly once, returning None."""
        actual_calls = self.cert_config_fetcher.getCalls()
        self.assertEqual(len(actual_calls), 1)
        if did_raise:
            self.assertTrue(actual_calls[0].did_raise)
        else:
            self.assertFalse(actual_calls[0].did_raise)
        self.assertIsNone(actual_calls[0].returned_cert_config)

    def _assert_fetched_without_raising(self, cert_config=None):
        """Asserts >= 1 fetch, none raising, each returning ``cert_config``."""
        actual_calls = self.cert_config_fetcher.getCalls()
        self.assertGreaterEqual(len(actual_calls), 1)
        for i, call in enumerate(actual_calls):
            self.assertFalse(call.did_raise, 'i= {}'.format(i))
            if cert_config is None:
                self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i))
            else:
                self.assertEqual(call.returned_cert_config, cert_config,
                                 'i= {}'.format(i))

    def _assert_not_fetched(self):
        """Asserts the fetcher was not invoked since the last reset."""
        self.assertEqual(len(self.cert_config_fetcher.getCalls()), 0)

    def _test(self):
        """Rotates the server certificate and checks new and existing channels.

        Before rotation the server presents cert 1 (hierarchy 1) and trusts
        CA 2 clients. Afterwards it presents cert 2 (hierarchy 2) and trusts
        CA 1 clients. One-shot clients must observe the rotation, while
        channels established before it keep working without new fetches.
        """
        # Things should work: the client trusts CA 1 and presents a CA 2 cert.
        # The fetcher was freshly created in setUp, so it is not reset here.
        self.cert_config_fetcher.configure(False, None)
        self._do_one_shot_client_rpc(True,
                                     root_certificates=CA_1_PEM,
                                     private_key=CLIENT_KEY_2_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        self._assert_single_fetch()

        # The client should reject the server: it trusts only CA 2, but the
        # server still presents cert 1.
        self._reconfigure_fetcher()
        self._do_one_shot_client_rpc(False,
                                     root_certificates=CA_2_PEM,
                                     private_key=CLIENT_KEY_2_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        self._assert_fetched_without_raising()

        # Should work again, even though the fetcher raises this time.
        self._reconfigure_fetcher(should_raise=True)
        self._do_one_shot_client_rpc(True,
                                     root_certificates=CA_1_PEM,
                                     private_key=CLIENT_KEY_2_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        self._assert_single_fetch(did_raise=True)

        # With client auth, the server rejects this client: it presents
        # key/cert 1, but the server trusts only CA 2.
        self._reconfigure_fetcher()
        self._do_one_shot_client_rpc(not self.require_client_auth(),
                                     root_certificates=CA_1_PEM,
                                     private_key=CLIENT_KEY_1_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_1_PEM)
        self._assert_fetched_without_raising()

        # Should work again.
        self._reconfigure_fetcher()
        self._do_one_shot_client_rpc(True,
                                     root_certificates=CA_1_PEM,
                                     private_key=CLIENT_KEY_2_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        self._assert_single_fetch()

        # Now create the "persistent" clients, connected before rotation.
        # Their channels are registered for cleanup so that they are closed
        # even if a later assertion fails.
        self._reconfigure_fetcher()
        channel_A = _create_channel(
            self.port,
            grpc.ssl_channel_credentials(
                root_certificates=CA_1_PEM,
                private_key=CLIENT_KEY_2_PEM,
                certificate_chain=CLIENT_CERT_CHAIN_2_PEM))
        self.addCleanup(channel_A.close)
        persistent_client_stub_A = _create_client_stub(channel_A, True)
        self._perform_rpc(persistent_client_stub_A, True)
        self._assert_single_fetch()

        self._reconfigure_fetcher()
        channel_B = _create_channel(
            self.port,
            grpc.ssl_channel_credentials(
                root_certificates=CA_1_PEM,
                private_key=CLIENT_KEY_2_PEM,
                certificate_chain=CLIENT_CERT_CHAIN_2_PEM))
        self.addCleanup(channel_B.close)
        persistent_client_stub_B = _create_client_stub(channel_B, True)
        self._perform_rpc(persistent_client_stub_B, True)
        self._assert_single_fetch()

        # Moment of truth! The fetcher hands the server cert 2 (which trusts
        # CA 1 clients), so a client that trusts only CA 1 rejects the server.
        cert_config = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)],
            root_certificates=CA_1_PEM)
        self._reconfigure_fetcher(cert_config=cert_config)
        self._do_one_shot_client_rpc(False,
                                     root_certificates=CA_1_PEM,
                                     private_key=CLIENT_KEY_2_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        self._assert_fetched_without_raising(cert_config)

        # Now it should work again: the client trusts CA 2 and presents a
        # CA 1 cert.
        self._reconfigure_fetcher()
        self._do_one_shot_client_rpc(True,
                                     root_certificates=CA_2_PEM,
                                     private_key=CLIENT_KEY_1_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_1_PEM)
        self._assert_single_fetch()

        # With client auth, the server rejects a client presenting a CA 2 cert.
        self._reconfigure_fetcher()
        self._do_one_shot_client_rpc(not self.require_client_auth(),
                                     root_certificates=CA_2_PEM,
                                     private_key=CLIENT_KEY_2_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        self._assert_fetched_without_raising()

        # Here the client rejects the server: the client trusts only CA 1.
        self._reconfigure_fetcher()
        self._do_one_shot_client_rpc(False,
                                     root_certificates=CA_1_PEM,
                                     private_key=CLIENT_KEY_2_PEM,
                                     certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        self._assert_fetched_without_raising()

        # Persistent clients continue to work and trigger no new fetches.
        self._reconfigure_fetcher()
        self._perform_rpc(persistent_client_stub_A, True)
        self._assert_not_fetched()

        self._reconfigure_fetcher()
        self._perform_rpc(persistent_client_stub_B, True)
        self._assert_not_fetched()
355
356
class ServerSSLCertConfigFetcherParamsChecks(unittest.TestCase):
    """Argument validation performed by grpc.dynamic_ssl_server_credentials."""

    def test_check_on_initial_config(self):
        # Neither None nor an int is a certificate configuration; `str` is
        # passed merely as some callable to use as the fetcher.
        for bad_initial_config in (None, 1):
            with self.assertRaises(TypeError):
                grpc.dynamic_ssl_server_credentials(bad_initial_config, str)

    def test_check_on_config_fetcher(self):
        valid_config = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)],
            root_certificates=CA_1_PEM)
        # Neither None nor an int is callable, so neither is a usable fetcher.
        for bad_fetcher in (None, 1):
            with self.assertRaises(TypeError):
                grpc.dynamic_ssl_server_credentials(valid_config, bad_fetcher)
373
374
class ServerSSLCertReloadTestWithClientAuth(_ServerSSLCertReloadTest):
    """Rotation scenario against a server that requires client certificates."""

    test = _ServerSSLCertReloadTest._test

    def require_client_auth(self):
        return True
381
382
class ServerSSLCertReloadTestWithoutClientAuth(_ServerSSLCertReloadTest):
    """Rotation scenario against a server that does not authenticate clients."""

    test = _ServerSSLCertReloadTest._test

    def require_client_auth(self):
        return False
389
390
class ServerSSLCertReloadTestCertConfigReuse(_ServerSSLCertReloadTest):
    """Ensures that `ServerCertificateConfiguration` instances can be reused.

    gRPC Core takes ownership of the `grpc_ssl_server_certificate_config`
    wrapped by a `ServerCertificateConfiguration`. This test therefore hands
    the same two `ServerCertificateConfiguration` instances to the server
    repeatedly. It checks that gRPC Python keeps those instances valid, so
    that a user application can re-use them.
    """

    def require_client_auth(self):
        return True

    def setUp(self):
        server = test_common.test_server()
        self.server = server
        services_pb2_grpc.add_FirstServiceServicer_to_server(
            _server_application.FirstServiceServicer(), server)
        # Config A presents server cert 1 and trusts CA 2 clients; config B
        # presents server cert 2 and trusts CA 1 clients.
        self.cert_config_A = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)],
            root_certificates=CA_2_PEM)
        self.cert_config_B = grpc.ssl_server_certificate_configuration(
            [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)],
            root_certificates=CA_1_PEM)
        self.cert_config_fetcher = CertConfigFetcher()
        self.port = server.add_secure_port(
            '[::]:0',
            grpc.dynamic_ssl_server_credentials(
                self.cert_config_A,
                self.cert_config_fetcher,
                require_client_authentication=True))
        server.start()

    def test_cert_config_reuse(self):
        # Client that trusts CA 1 and presents the hierarchy-2 client cert:
        # accepted exactly when the server uses config A.
        ca1_root_client2 = dict(root_certificates=CA_1_PEM,
                                private_key=CLIENT_KEY_2_PEM,
                                certificate_chain=CLIENT_CERT_CHAIN_2_PEM)
        # Client that trusts CA 2 and presents the hierarchy-1 client cert:
        # accepted exactly when the server uses config B.
        ca2_root_client1 = dict(root_certificates=CA_2_PEM,
                                private_key=CLIENT_KEY_1_PEM,
                                certificate_chain=CLIENT_CERT_CHAIN_1_PEM)
        # Each step: (config the fetcher returns, RPC should succeed,
        # client credentials).
        scenario = (
            (self.cert_config_A, True, ca1_root_client2),  # succeed with A
            (self.cert_config_A, False, ca2_root_client1),  # fail with A
            (self.cert_config_A, True, ca1_root_client2),  # succeed again with A
            (self.cert_config_B, True, ca2_root_client1),  # succeed with B
            (self.cert_config_B, False, ca1_root_client2),  # fail with B
            (self.cert_config_B, True, ca2_root_client1),  # succeed again with B
        )
        for expected_config, should_succeed, client_kwargs in scenario:
            self.cert_config_fetcher.reset()
            self.cert_config_fetcher.configure(False, expected_config)
            self._do_one_shot_client_rpc(should_succeed, **client_kwargs)
            recorded = self.cert_config_fetcher.getCalls()
            if should_succeed:
                self.assertEqual(len(recorded), 1)
                self.assertFalse(recorded[0].did_raise)
                self.assertEqual(recorded[0].returned_cert_config,
                                 expected_config)
                continue
            self.assertGreaterEqual(len(recorded), 1)
            self.assertFalse(recorded[0].did_raise)
            for idx, entry in enumerate(recorded):
                self.assertFalse(entry.did_raise, 'i= {}'.format(idx))
                self.assertEqual(entry.returned_cert_config, expected_config,
                                 'i= {}'.format(idx))
507
508
if __name__ == '__main__':
    # Give the root logger a default stderr handler before running the suite.
    logging.basicConfig()
    unittest.main(
        verbosity=2,
    )
512