• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import datetime
15import enum
16import hashlib
17import logging
18import time
19from typing import Optional, Tuple
20
21from absl import flags
22from absl.testing import absltest
23
24from framework import xds_flags
25from framework import xds_k8s_flags
26from framework.helpers import retryers
27from framework.infrastructure import gcp
28from framework.infrastructure import k8s
29from framework.infrastructure import traffic_director
30from framework.rpc import grpc_channelz
31from framework.rpc import grpc_testing
32from framework.test_app import client_app
33from framework.test_app import server_app
34
# Module-level logger used by every test-case class in this file.
logger = logging.getLogger(__name__)

# Flags defined by this module.
_FORCE_CLEANUP = flags.DEFINE_bool(
    "force_cleanup",
    default=False,
    help="Force resource cleanup, even if not created by this test run",
)
# TODO(yashkt): We will no longer need this flag once Core exposes local certs
# from channelz
_CHECK_LOCAL_CERTS = flags.DEFINE_bool(
    "check_local_certs",
    default=True,
    help="Security Tests also check the value of local certs",
)
# Adopt the flags declared in the shared flag modules as key flags here.
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)

# Type aliases: short local names for types used in the signatures below.
XdsTestServer = server_app.XdsTestServer
XdsTestClient = client_app.XdsTestClient
LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse
_ChannelState = grpc_channelz.ChannelState
_timedelta = datetime.timedelta
_DEFAULT_SECURE_MODE_MAINTENANCE_PORT = (
    server_app.KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT)
57
58
class XdsKubernetesTestCase(absltest.TestCase):
    """Base class for xDS test cases that run test apps on Kubernetes.

    ``setUpClass`` copies flag values onto the class and creates the
    Kubernetes and GCP API managers. ``setUp`` only resets ``self.td``,
    ``self.server_runner`` and ``self.client_runner`` to ``None``; child
    classes are expected to assign them in their own ``setUp``.
    """
    k8s_api_manager: k8s.KubernetesApiManager
    gcp_api_manager: gcp.api.GcpApiManager

    @classmethod
    def setUpClass(cls):
        """Load test settings from flags and create the API managers."""
        # GCP
        cls.project: str = xds_flags.PROJECT.value
        cls.network: str = xds_flags.NETWORK.value
        cls.gcp_service_account: str = xds_k8s_flags.GCP_SERVICE_ACCOUNT.value
        cls.td_bootstrap_image = xds_k8s_flags.TD_BOOTSTRAP_IMAGE.value
        cls.xds_server_uri = xds_flags.XDS_SERVER_URI.value

        # Base namespace
        # TODO(sergiitk): generate for each test
        cls.namespace: str = xds_flags.NAMESPACE.value

        # Test server
        cls.server_image = xds_k8s_flags.SERVER_IMAGE.value
        cls.server_name = xds_flags.SERVER_NAME.value
        cls.server_port = xds_flags.SERVER_PORT.value
        cls.server_maintenance_port = xds_flags.SERVER_MAINTENANCE_PORT.value
        # The xDS host is taken from the same flag as the server name.
        cls.server_xds_host = xds_flags.SERVER_NAME.value
        cls.server_xds_port = xds_flags.SERVER_XDS_PORT.value

        # Test client
        cls.client_image = xds_k8s_flags.CLIENT_IMAGE.value
        cls.client_name = xds_flags.CLIENT_NAME.value
        cls.client_port = xds_flags.CLIENT_PORT.value

        # Test suite settings
        cls.force_cleanup = _FORCE_CLEANUP.value
        cls.debug_use_port_forwarding = \
            xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
        cls.check_local_certs = _CHECK_LOCAL_CERTS.value

        # Resource managers
        cls.k8s_api_manager = k8s.KubernetesApiManager(
            xds_k8s_flags.KUBE_CONTEXT.value)
        cls.gcp_api_manager = gcp.api.GcpApiManager()

    def setUp(self):
        """Set per-test namespaces and reset the resource handles."""
        # TODO(sergiitk): generate namespace with run id for each test
        self.server_namespace = self.namespace
        self.client_namespace = self.namespace

        # Init this in child class
        # TODO(sergiitk): consider making a method to be less error-prone
        self.server_runner = None
        self.client_runner = None
        self.td = None

    @classmethod
    def tearDownClass(cls):
        """Close the API managers created in ``setUpClass``."""
        cls.k8s_api_manager.close()
        cls.gcp_api_manager.close()

    def tearDown(self):
        """Run ``_cleanup`` through a retryer; log instead of raising.

        The retryer is configured with 3 attempts and a fixed 10-second
        wait. A ``retryers.RetryError`` is logged with its traceback and
        swallowed, so a cleanup failure does not fail the test.
        """
        logger.info('----- TestMethod %s teardown -----', self.id())
        retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10),
                                            attempts=3,
                                            log_level=logging.INFO)
        try:
            retryer(self._cleanup)
        except retryers.RetryError:
            logger.exception('Got error during teardown')

    def _cleanup(self):
        """Clean up Traffic Director, client and server resources.

        Handles still set to ``None`` (never assigned by the child class)
        are skipped, so one missing handle does not block the cleanup of
        the others or trigger pointless retries.
        """
        force = self.force_cleanup
        if self.td is not None:
            self.td.cleanup(force=force)
        if self.client_runner is not None:
            self.client_runner.cleanup(force=force)
        if self.server_runner is not None:
            self.server_runner.cleanup(force=force, force_namespace=force)

    def setupTrafficDirectorGrpc(self):
        """Configure Traffic Director for gRPC with the server xDS address.

        The server maintenance port is passed as the health check port.
        """
        self.td.setup_for_grpc(self.server_xds_host,
                               self.server_xds_port,
                               health_check_port=self.server_maintenance_port)

    def setupServerBackends(self, *, wait_for_healthy_status=True):
        """Add the server service NEG as backends of the Backend Service.

        Args:
            wait_for_healthy_status: When true, also wait for the backends
                to report healthy status.
        """
        # Load Backends
        neg_name, neg_zones = self.server_runner.k8s_namespace.get_service_neg(
            self.server_runner.service_name, self.server_port)

        # Add backends to the Backend Service
        self.td.backend_service_add_neg_backends(neg_name, neg_zones)
        if wait_for_healthy_status:
            self.td.wait_for_backends_healthy_status()

    def assertSuccessfulRpcs(self,
                             test_client: XdsTestClient,
                             num_rpcs: int = 100):
        """Assert every backend received RPCs and none of the RPCs failed.

        Args:
            test_client: The client asked for load balancer stats.
            num_rpcs: Number of RPCs to request stats for.
        """
        lb_stats = self.sendRpcs(test_client, num_rpcs)
        self.assertAllBackendsReceivedRpcs(lb_stats)
        failed = int(lb_stats.num_failures)
        self.assertLessEqual(
            failed,
            0,
            msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')

    def assertFailedRpcs(self,
                         test_client: XdsTestClient,
                         num_rpcs: int = 100):
        """Assert that all of the ``num_rpcs`` RPCs failed.

        Args:
            test_client: The client asked for load balancer stats.
            num_rpcs: Number of RPCs to request stats for; the reported
                failure count must equal this number.
        """
        lb_stats = self.sendRpcs(test_client, num_rpcs)
        failed = int(lb_stats.num_failures)
        self.assertEqual(
            failed,
            num_rpcs,
            msg=f'Expected all RPCs to fail: {failed} of {num_rpcs} failed')

    @staticmethod
    def sendRpcs(test_client: XdsTestClient,
                 num_rpcs: int) -> LoadBalancerStatsResponse:
        """Fetch, log and return load balancer stats for ``num_rpcs`` RPCs."""
        lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
        logger.info(
            'Received LoadBalancerStatsResponse from test client %s:\n%s',
            test_client.ip, lb_stats)
        return lb_stats

    def assertAllBackendsReceivedRpcs(self, lb_stats):
        """Assert every peer in ``lb_stats.rpcs_by_peer`` got RPCs.

        Also fails when ``rpcs_by_peer`` is empty: with no peers listed the
        per-backend loop would otherwise pass without checking anything.
        """
        # TODO(sergiitk): assert backends length
        self.assertNotEmpty(lb_stats.rpcs_by_peer,
                            msg='No backend received a single RPC')
        for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
            self.assertGreater(
                int(rpcs_count),
                0,
                msg=f'Backend {backend} did not receive a single RPC')
184
185
class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
    """Test case wiring up the regular Traffic Director manager and runners."""

    def setUp(self):
        """Create the Traffic Director manager and both app runners."""
        super().setUp()

        # Traffic Director Configuration
        self.td = traffic_director.TrafficDirectorManager(
            self.gcp_api_manager,
            project=self.project,
            resource_prefix=self.namespace,
            network=self.network)

        # Keyword arguments passed identically to both runners.
        shared_runner_args = dict(
            gcp_service_account=self.gcp_service_account,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            network=self.network)

        # Test Server Runner
        server_k8s_namespace = k8s.KubernetesNamespace(self.k8s_api_manager,
                                                       self.server_namespace)
        self.server_runner = server_app.KubernetesServerRunner(
            server_k8s_namespace,
            deployment_name=self.server_name,
            image_name=self.server_image,
            **shared_runner_args)

        # Test Client Runner
        client_k8s_namespace = k8s.KubernetesNamespace(self.k8s_api_manager,
                                                       self.client_namespace)
        # The client reuses the namespace when it is the server's namespace.
        same_namespace = self.server_namespace == self.client_namespace
        self.client_runner = client_app.KubernetesClientRunner(
            client_k8s_namespace,
            deployment_name=self.client_name,
            image_name=self.client_image,
            debug_use_port_forwarding=self.debug_use_port_forwarding,
            stats_port=self.client_port,
            reuse_namespace=same_namespace,
            **shared_runner_args)

    def startTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
        """Run the test server and set its xDS address.

        Args:
            replica_count: Number of replicas passed to the server runner.
            **kwargs: Extra keyword arguments forwarded to the runner.

        Returns:
            The test server returned by the runner.
        """
        run_args = dict(replica_count=replica_count,
                        test_port=self.server_port,
                        maintenance_port=self.server_maintenance_port)
        server = self.server_runner.run(**run_args, **kwargs)
        server.set_xds_address(self.server_xds_host, self.server_xds_port)
        return server

    def startTestClient(self, test_server: XdsTestServer,
                        **kwargs) -> XdsTestClient:
        """Run the test client against ``test_server.xds_uri``.

        Waits for the client's active server channel before returning it.
        """
        target = test_server.xds_uri
        client = self.client_runner.run(server_target=target, **kwargs)
        client.wait_for_active_server_channel()
        return client
238
239
class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
    """Test case wiring up the secure Traffic Director manager and runners.

    Adds helpers to start the test apps in secure mode, configure security
    policies, and assert the security state of the connected sockets.
    """

    class SecurityMode(enum.Enum):
        """Security modes accepted by ``assertTestAppSecurity``."""
        MTLS = enum.auto()
        TLS = enum.auto()
        PLAINTEXT = enum.auto()

    @classmethod
    def setUpClass(cls):
        """Load settings and default the maintenance port when unset."""
        super().setUpClass()
        if cls.server_maintenance_port is None:
            # In secure mode, the maintenance port is different from
            # the test port to keep it insecure, and make
            # Health Checks and Channelz tests available.
            # When not provided, use explicit numeric port value, so
            # Backend Health Checks are created on a fixed port.
            cls.server_maintenance_port = _DEFAULT_SECURE_MODE_MAINTENANCE_PORT

    def setUp(self):
        """Create the secure Traffic Director manager and both app runners."""
        super().setUp()

        # Traffic Director Configuration
        self.td = traffic_director.TrafficDirectorSecureManager(
            self.gcp_api_manager,
            project=self.project,
            resource_prefix=self.namespace,
            network=self.network)

        # Test Server Runner
        self.server_runner = server_app.KubernetesServerRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.server_namespace),
            deployment_name=self.server_name,
            image_name=self.server_image,
            gcp_service_account=self.gcp_service_account,
            network=self.network,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            deployment_template='server-secure.deployment.yaml',
            debug_use_port_forwarding=self.debug_use_port_forwarding)

        # Test Client Runner
        self.client_runner = client_app.KubernetesClientRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.client_namespace),
            deployment_name=self.client_name,
            image_name=self.client_image,
            gcp_service_account=self.gcp_service_account,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            deployment_template='client-secure.deployment.yaml',
            stats_port=self.client_port,
            reuse_namespace=self.server_namespace == self.client_namespace,
            debug_use_port_forwarding=self.debug_use_port_forwarding)

    def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
        """Run the test server with ``secure_mode=True``; set its xDS address.

        Args:
            replica_count: Number of replicas passed to the server runner.
            **kwargs: Extra keyword arguments forwarded to the runner.

        Returns:
            The test server returned by the runner.
        """
        test_server = self.server_runner.run(
            replica_count=replica_count,
            test_port=self.server_port,
            maintenance_port=self.server_maintenance_port,
            secure_mode=True,
            **kwargs)
        test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
        return test_server

    def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls,
                              client_mtls):
        """Set up client and server security in Traffic Director.

        Args:
            server_tls: ``tls`` value for the server security setup.
            server_mtls: ``mtls`` value for the server security setup.
            client_tls: ``tls`` value for the client security setup.
            client_mtls: ``mtls`` value for the client security setup.
        """
        self.td.setup_client_security(server_namespace=self.server_namespace,
                                      server_name=self.server_name,
                                      tls=client_tls,
                                      mtls=client_mtls)
        self.td.setup_server_security(server_namespace=self.server_namespace,
                                      server_name=self.server_name,
                                      server_port=self.server_port,
                                      tls=server_tls,
                                      mtls=server_mtls)

    def startSecureTestClient(self,
                              test_server: XdsTestServer,
                              *,
                              wait_for_active_server_channel=True,
                              **kwargs) -> XdsTestClient:
        """Run the test client with ``secure_mode=True``.

        Args:
            test_server: Server whose ``xds_uri`` is the client's target.
            wait_for_active_server_channel: When true, wait for the client's
                active server channel before returning.
            **kwargs: Extra keyword arguments forwarded to the runner.

        Returns:
            The test client returned by the runner.
        """
        test_client = self.client_runner.run(server_target=test_server.xds_uri,
                                             secure_mode=True,
                                             **kwargs)
        if wait_for_active_server_channel:
            test_client.wait_for_active_server_channel()
        return test_client

    def assertTestAppSecurity(self, mode: SecurityMode,
                              test_client: XdsTestClient,
                              test_server: XdsTestServer):
        """Assert the connected sockets match the given security mode.

        Args:
            mode: The expected ``SecurityMode``.
            test_client: The client side of the connection.
            test_server: The server side of the connection.

        Raises:
            TypeError: If ``mode`` is not a ``SecurityMode`` member.
        """
        client_socket, server_socket = self.getConnectedSockets(
            test_client, test_server)
        server_security: grpc_channelz.Security = server_socket.security
        client_security: grpc_channelz.Security = client_socket.security
        logger.info('Server certs: %s', self.debug_sock_certs(server_security))
        logger.info('Client certs: %s', self.debug_sock_certs(client_security))

        if mode is self.SecurityMode.MTLS:
            self.assertSecurityMtls(client_security, server_security)
        elif mode is self.SecurityMode.TLS:
            self.assertSecurityTls(client_security, server_security)
        elif mode is self.SecurityMode.PLAINTEXT:
            self.assertSecurityPlaintext(client_security, server_security)
        else:
            raise TypeError('Incorrect security mode')

    def assertSecurityMtls(self, client_security: grpc_channelz.Security,
                           server_security: grpc_channelz.Security):
        """Assert both sockets use TLS and certificates match in both ways.

        Checks involving local certificates run only when
        ``self.check_local_certs`` is true.
        """
        self.assertEqual(client_security.WhichOneof('model'),
                         'tls',
                         msg='(mTLS) Client socket security model must be TLS')
        self.assertEqual(server_security.WhichOneof('model'),
                         'tls',
                         msg='(mTLS) Server socket security model must be TLS')
        server_tls, client_tls = server_security.tls, client_security.tls

        # Confirm regular TLS: server local cert == client remote cert
        self.assertNotEmpty(client_tls.remote_certificate,
                            msg="(mTLS) Client remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(
                server_tls.local_certificate,
                msg="(mTLS) Server local certificate is missing")
            self.assertEqual(
                server_tls.local_certificate,
                client_tls.remote_certificate,
                msg="(mTLS) Server local certificate must match client's "
                "remote certificate")

        # mTLS: server remote cert == client local cert
        self.assertNotEmpty(server_tls.remote_certificate,
                            msg="(mTLS) Server remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(
                client_tls.local_certificate,
                msg="(mTLS) Client local certificate is missing")
            self.assertEqual(
                server_tls.remote_certificate,
                client_tls.local_certificate,
                msg="(mTLS) Server remote certificate must match client's "
                "local certificate")

    def assertSecurityTls(self, client_security: grpc_channelz.Security,
                          server_security: grpc_channelz.Security):
        """Assert both sockets use TLS with no client certificate in use.

        Checks that compare the server local certificate run only when
        ``self.check_local_certs`` is true.
        """
        self.assertEqual(client_security.WhichOneof('model'),
                         'tls',
                         msg='(TLS) Client socket security model must be TLS')
        self.assertEqual(server_security.WhichOneof('model'),
                         'tls',
                         msg='(TLS) Server socket security model must be TLS')
        server_tls, client_tls = server_security.tls, client_security.tls

        # Regular TLS: server local cert == client remote cert
        self.assertNotEmpty(client_tls.remote_certificate,
                            msg="(TLS) Client remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(server_tls.local_certificate,
                                msg="(TLS) Server local certificate is missing")
            self.assertEqual(
                server_tls.local_certificate,
                client_tls.remote_certificate,
                msg="(TLS) Server local certificate must match client "
                "remote certificate")

        # mTLS must not be used
        self.assertEmpty(
            server_tls.remote_certificate,
            msg="(TLS) Server remote certificate must be empty in TLS mode. "
            "Is server security incorrectly configured for mTLS?")
        self.assertEmpty(
            client_tls.local_certificate,
            msg="(TLS) Client local certificate must be empty in TLS mode. "
            "Is client security incorrectly configured for mTLS?")

    def assertSecurityPlaintext(self, client_security, server_security):
        """Assert that neither socket reports any TLS certificate.

        All four certificate fields are checked: local and remote, on both
        the server and the client socket.
        """
        server_tls, client_tls = server_security.tls, client_security.tls
        # Not TLS: no server certificate on either end of the connection.
        self.assertEmpty(
            server_tls.local_certificate,
            msg="(Plaintext) Server local certificate must be empty.")
        self.assertEmpty(
            client_tls.remote_certificate,
            msg="(Plaintext) Client remote certificate must be empty.")

        # Not mTLS: no client certificate on either end of the connection.
        self.assertEmpty(
            server_tls.remote_certificate,
            msg="(Plaintext) Server remote certificate must be empty.")
        self.assertEmpty(
            client_tls.local_certificate,
            msg="(Plaintext) Client local certificate must be empty.")

    def assertClientCannotReachServerRepeatedly(
            self,
            test_client: XdsTestClient,
            *,
            times: Optional[int] = None,
            delay: Optional[_timedelta] = None):
        """
        Asserts that the client repeatedly cannot reach the server.

        With negative tests we can't be absolutely certain expected failure
        state is not caused by something else.
        To mitigate for this, we repeat the checks several times, and expect
        all of them to succeed.

        This is useful in case the channel eventually stabilizes, and RPCs pass.

        Args:
            test_client: An instance of XdsTestClient
            times: Optional; A positive number of times to confirm that
                the server is unreachable. Defaults to `3` attempts;
                values below `1` also fall back to `3`.
            delay: Optional; Specifies how long to wait before the next check.
                Defaults to `10` seconds.
        """
        if times is None or times < 1:
            times = 3
        if delay is None:
            delay = _timedelta(seconds=10)

        for i in range(1, times + 1):
            self.assertClientCannotReachServer(test_client)
            # No need to wait after the final check.
            if i < times:
                logger.info('Check %s passed, waiting %s before the next check',
                            i, delay)
                time.sleep(delay.total_seconds())

    def assertClientCannotReachServer(self, test_client: XdsTestClient):
        """Assert the client channel has failed and that RPCs fail."""
        self.assertClientChannelFailed(test_client)
        self.assertFailedRpcs(test_client)

    def assertClientChannelFailed(self, test_client: XdsTestClient):
        """Wait for TRANSIENT_FAILURE, then assert exactly one subchannel."""
        channel = test_client.wait_for_server_channel_state(
            state=_ChannelState.TRANSIENT_FAILURE)
        subchannels = list(
            test_client.channelz.list_channel_subchannels(channel))
        self.assertLen(subchannels,
                       1,
                       msg="Client channel must have exactly one subchannel "
                       "in state TRANSIENT_FAILURE.")

    @staticmethod
    def getConnectedSockets(
        test_client: XdsTestClient, test_server: XdsTestServer
    ) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
        """Return the client socket and the server socket matching it.

        Returns:
            A ``(client_socket, server_socket)`` tuple.
        """
        client_sock = test_client.get_active_server_channel_socket()
        server_sock = test_server.get_server_socket_matching_client(client_sock)
        return client_sock, server_sock

    @classmethod
    def debug_sock_certs(cls, security: grpc_channelz.Security):
        """Return a loggable summary of a socket's security info.

        For the ``other`` security model the name and value are shown;
        otherwise the local and remote TLS certificates are summarized
        with ``debug_cert``.
        """
        if security.WhichOneof('model') == 'other':
            return f'other: <{security.other.name}={security.other.value}>'

        return (f'local: <{cls.debug_cert(security.tls.local_certificate)}>, '
                f'remote: <{cls.debug_cert(security.tls.remote_certificate)}>')

    @staticmethod
    def debug_cert(cert):
        """Return the SHA-1 hex digest and length of ``cert`` for logging.

        Returns ``'missing'`` when ``cert`` is empty or ``None``. The digest
        only identifies the certificate in logs.
        """
        if not cert:
            return 'missing'
        sha1 = hashlib.sha1(cert)
        return f'sha1={sha1.hexdigest()}, len={len(cert)}'
506