1# Copyright 2020 gRPC authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14import datetime 15import enum 16import hashlib 17import logging 18import time 19from typing import Optional, Tuple 20 21from absl import flags 22from absl.testing import absltest 23 24from framework import xds_flags 25from framework import xds_k8s_flags 26from framework.helpers import retryers 27from framework.infrastructure import gcp 28from framework.infrastructure import k8s 29from framework.infrastructure import traffic_director 30from framework.rpc import grpc_channelz 31from framework.rpc import grpc_testing 32from framework.test_app import client_app 33from framework.test_app import server_app 34 35logger = logging.getLogger(__name__) 36_FORCE_CLEANUP = flags.DEFINE_bool( 37 "force_cleanup", 38 default=False, 39 help="Force resource cleanup, even if not created by this test run") 40# TODO(yashkt): We will no longer need this flag once Core exposes local certs 41# from channelz 42_CHECK_LOCAL_CERTS = flags.DEFINE_bool( 43 "check_local_certs", 44 default=True, 45 help="Security Tests also check the value of local certs") 46flags.adopt_module_key_flags(xds_flags) 47flags.adopt_module_key_flags(xds_k8s_flags) 48 49# Type aliases 50XdsTestServer = server_app.XdsTestServer 51XdsTestClient = client_app.XdsTestClient 52LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse 53_ChannelState = grpc_channelz.ChannelState 54_timedelta = datetime.timedelta 
_DEFAULT_SECURE_MODE_MAINTENANCE_PORT = \
    server_app.KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT


class XdsKubernetesTestCase(absltest.TestCase):
    """Base class for xDS interop tests running on Kubernetes.

    Reads the test configuration from command-line flags once per class and
    owns the Kubernetes and GCP API managers shared by all test methods.
    Subclasses are expected to initialize ``self.td``, ``self.server_runner``
    and ``self.client_runner`` in their ``setUp()``.
    """
    k8s_api_manager: k8s.KubernetesApiManager
    gcp_api_manager: gcp.api.GcpApiManager

    @classmethod
    def setUpClass(cls):
        """Loads flag values and creates the shared API managers."""
        # GCP
        cls.project: str = xds_flags.PROJECT.value
        cls.network: str = xds_flags.NETWORK.value
        cls.gcp_service_account: str = xds_k8s_flags.GCP_SERVICE_ACCOUNT.value
        cls.td_bootstrap_image = xds_k8s_flags.TD_BOOTSTRAP_IMAGE.value
        cls.xds_server_uri = xds_flags.XDS_SERVER_URI.value

        # Base namespace
        # TODO(sergiitk): generate for each test
        cls.namespace: str = xds_flags.NAMESPACE.value

        # Test server
        cls.server_image = xds_k8s_flags.SERVER_IMAGE.value
        cls.server_name = xds_flags.SERVER_NAME.value
        cls.server_port = xds_flags.SERVER_PORT.value
        cls.server_maintenance_port = xds_flags.SERVER_MAINTENANCE_PORT.value
        cls.server_xds_host = xds_flags.SERVER_NAME.value
        cls.server_xds_port = xds_flags.SERVER_XDS_PORT.value

        # Test client
        cls.client_image = xds_k8s_flags.CLIENT_IMAGE.value
        cls.client_name = xds_flags.CLIENT_NAME.value
        cls.client_port = xds_flags.CLIENT_PORT.value

        # Test suite settings
        cls.force_cleanup = _FORCE_CLEANUP.value
        cls.debug_use_port_forwarding = \
            xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
        cls.check_local_certs = _CHECK_LOCAL_CERTS.value

        # Resource managers
        cls.k8s_api_manager = k8s.KubernetesApiManager(
            xds_k8s_flags.KUBE_CONTEXT.value)
        cls.gcp_api_manager = gcp.api.GcpApiManager()

    def setUp(self):
        """Sets per-test namespaces and resets the components to ``None``."""
        # TODO(sergiitk): generate namespace with run id for each test
        self.server_namespace = self.namespace
        self.client_namespace = self.namespace

        # Init this in child class
        # TODO(sergiitk): consider making a method to be less error-prone
        self.server_runner = None
        self.client_runner = None
        self.td = None

    @classmethod
    def tearDownClass(cls):
        """Closes the API managers created in ``setUpClass``."""
        cls.k8s_api_manager.close()
        cls.gcp_api_manager.close()

    def tearDown(self):
        """Cleans up test resources, retrying up to 3 times, 10s apart.

        If every attempt fails, the resulting ``RetryError`` is logged
        instead of being raised.
        """
        logger.info('----- TestMethod %s teardown -----', self.id())
        retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10),
                                            attempts=3,
                                            log_level=logging.INFO)
        try:
            retryer(self._cleanup)
        except retryers.RetryError:
            logger.exception('Got error during teardown')

    def _cleanup(self):
        """Cleans up Traffic Director, client and server resources.

        Components that were never initialized (still ``None``) are skipped.
        Otherwise a subclass that did not set up every component would fail
        each retry with ``AttributeError`` and never clean up the rest.
        """
        if self.td is not None:
            self.td.cleanup(force=self.force_cleanup)
        if self.client_runner is not None:
            self.client_runner.cleanup(force=self.force_cleanup)
        if self.server_runner is not None:
            self.server_runner.cleanup(force=self.force_cleanup,
                                       force_namespace=self.force_cleanup)

    def setupTrafficDirectorGrpc(self):
        """Configures Traffic Director for gRPC on the server's xDS address."""
        self.td.setup_for_grpc(self.server_xds_host,
                               self.server_xds_port,
                               health_check_port=self.server_maintenance_port)

    def setupServerBackends(self, *, wait_for_healthy_status=True):
        """Adds the server service's NEG as backends of the backend service.

        Args:
            wait_for_healthy_status: When True (default), blocks until Traffic
                Director reports the backends healthy.
        """
        # Load Backends
        neg_name, neg_zones = self.server_runner.k8s_namespace.get_service_neg(
            self.server_runner.service_name, self.server_port)

        # Add backends to the Backend Service
        self.td.backend_service_add_neg_backends(neg_name, neg_zones)
        if wait_for_healthy_status:
            self.td.wait_for_backends_healthy_status()

    def assertSuccessfulRpcs(self,
                             test_client: XdsTestClient,
                             num_rpcs: int = 100):
        """Asserts that every backend got RPCs and no RPC failed."""
        lb_stats = self.sendRpcs(test_client, num_rpcs)
        self.assertAllBackendsReceivedRpcs(lb_stats)
        failed = int(lb_stats.num_failures)
        self.assertLessEqual(
            failed,
            0,
            msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')

    def assertFailedRpcs(self,
                         test_client: XdsTestClient,
                         num_rpcs: Optional[int] = 100):
        """Asserts that all of ``num_rpcs`` RPCs failed.

        Args:
            test_client: The client to send RPCs from.
            num_rpcs: Number of RPCs to send. ``None`` falls back to 100.
        """
        # The annotation allows None; without this, None would be sent to the
        # client and compared against the failure count, which never matches.
        if num_rpcs is None:
            num_rpcs = 100
        lb_stats = self.sendRpcs(test_client, num_rpcs)
        failed = int(lb_stats.num_failures)
        self.assertEqual(
            failed,
            num_rpcs,
            msg=f'Expected all RPCs to fail: {failed} of {num_rpcs} failed')

    @staticmethod
    def sendRpcs(test_client: XdsTestClient,
                 num_rpcs: int) -> LoadBalancerStatsResponse:
        """Requests load-balancer stats for ``num_rpcs`` RPCs and logs them."""
        lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
        logger.info(
            'Received LoadBalancerStatsResponse from test client %s:\n%s',
            test_client.ip, lb_stats)
        return lb_stats

    def assertAllBackendsReceivedRpcs(self, lb_stats):
        """Asserts every peer listed in ``lb_stats`` received at least 1 RPC."""
        # TODO(sergiitk): assert backends length
        for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
            self.assertGreater(
                int(rpcs_count),
                0,
                msg=f'Backend {backend} did not receive a single RPC')


class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
    """Test case using the regular (non-secure) Traffic Director setup."""

    def setUp(self):
        """Creates the Traffic Director manager and the app runners."""
        super().setUp()

        # Traffic Director Configuration
        self.td = traffic_director.TrafficDirectorManager(
            self.gcp_api_manager,
            project=self.project,
            resource_prefix=self.namespace,
            network=self.network)

        # Test Server Runner
        self.server_runner = server_app.KubernetesServerRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.server_namespace),
            deployment_name=self.server_name,
            image_name=self.server_image,
            gcp_service_account=self.gcp_service_account,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            network=self.network)

        # Test Client Runner
        self.client_runner = client_app.KubernetesClientRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.client_namespace),
            deployment_name=self.client_name,
            image_name=self.client_image,
            gcp_service_account=self.gcp_service_account,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            debug_use_port_forwarding=self.debug_use_port_forwarding,
            stats_port=self.client_port,
            reuse_namespace=self.server_namespace == self.client_namespace)

    def startTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
        """Runs the test server and sets its xDS address."""
        test_server = self.server_runner.run(
            replica_count=replica_count,
            test_port=self.server_port,
            maintenance_port=self.server_maintenance_port,
            **kwargs)
        test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
        return test_server

    def startTestClient(self, test_server: XdsTestServer,
                        **kwargs) -> XdsTestClient:
        """Runs a client targeting ``test_server`` and waits for its channel."""
        test_client = self.client_runner.run(server_target=test_server.xds_uri,
                                             **kwargs)
        test_client.wait_for_active_server_channel()
        return test_client


class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
    """Test case using the secure Traffic Director setup and secure apps."""

    class SecurityMode(enum.Enum):
        MTLS = enum.auto()
        TLS = enum.auto()
        PLAINTEXT = enum.auto()

    @classmethod
    def setUpClass(cls):
        """Loads flags; defaults the maintenance port for secure mode."""
        super().setUpClass()
        if cls.server_maintenance_port is None:
            # In secure mode, the maintenance port is different from
            # the test port to keep it insecure, and make
            # Health Checks and Channelz tests available.
            # When not provided, use explicit numeric port value, so
            # Backend Health Checks are created on a fixed port.
            cls.server_maintenance_port = _DEFAULT_SECURE_MODE_MAINTENANCE_PORT

    def setUp(self):
        """Creates the secure Traffic Director manager and the app runners."""
        super().setUp()

        # Traffic Director Configuration
        self.td = traffic_director.TrafficDirectorSecureManager(
            self.gcp_api_manager,
            project=self.project,
            resource_prefix=self.namespace,
            network=self.network)

        # Test Server Runner
        self.server_runner = server_app.KubernetesServerRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.server_namespace),
            deployment_name=self.server_name,
            image_name=self.server_image,
            gcp_service_account=self.gcp_service_account,
            network=self.network,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            deployment_template='server-secure.deployment.yaml',
            debug_use_port_forwarding=self.debug_use_port_forwarding)

        # Test Client Runner
        self.client_runner = client_app.KubernetesClientRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.client_namespace),
            deployment_name=self.client_name,
            image_name=self.client_image,
            gcp_service_account=self.gcp_service_account,
            td_bootstrap_image=self.td_bootstrap_image,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            deployment_template='client-secure.deployment.yaml',
            stats_port=self.client_port,
            reuse_namespace=self.server_namespace == self.client_namespace,
            debug_use_port_forwarding=self.debug_use_port_forwarding)

    def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
        """Runs the test server in secure mode and sets its xDS address."""
        test_server = self.server_runner.run(
            replica_count=replica_count,
            test_port=self.server_port,
            maintenance_port=self.server_maintenance_port,
            secure_mode=True,
            **kwargs)
        test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
        return test_server

    def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls,
                              client_mtls):
        """Configures client-side and server-side security in Traffic Director.

        Args:
            server_tls: Passed as ``tls`` to the server security setup.
            server_mtls: Passed as ``mtls`` to the server security setup.
            client_tls: Passed as ``tls`` to the client security setup.
            client_mtls: Passed as ``mtls`` to the client security setup.
        """
        self.td.setup_client_security(server_namespace=self.server_namespace,
                                      server_name=self.server_name,
                                      tls=client_tls,
                                      mtls=client_mtls)
        self.td.setup_server_security(server_namespace=self.server_namespace,
                                      server_name=self.server_name,
                                      server_port=self.server_port,
                                      tls=server_tls,
                                      mtls=server_mtls)

    def startSecureTestClient(self,
                              test_server: XdsTestServer,
                              *,
                              wait_for_active_server_channel=True,
                              **kwargs) -> XdsTestClient:
        """Runs a secure-mode client targeting ``test_server``.

        Args:
            test_server: The server whose xDS URI the client targets.
            wait_for_active_server_channel: When True (default), waits until
                the client reports an active channel to the server.
            **kwargs: Forwarded to the client runner's ``run()``.
        """
        test_client = self.client_runner.run(server_target=test_server.xds_uri,
                                             secure_mode=True,
                                             **kwargs)
        if wait_for_active_server_channel:
            test_client.wait_for_active_server_channel()
        return test_client

    def assertTestAppSecurity(self, mode: SecurityMode,
                              test_client: XdsTestClient,
                              test_server: XdsTestServer):
        """Asserts the client-server connection uses the given security mode.

        Raises:
            TypeError: If ``mode`` is not a known ``SecurityMode``.
        """
        client_socket, server_socket = self.getConnectedSockets(
            test_client, test_server)
        server_security: grpc_channelz.Security = server_socket.security
        client_security: grpc_channelz.Security = client_socket.security
        logger.info('Server certs: %s', self.debug_sock_certs(server_security))
        logger.info('Client certs: %s', self.debug_sock_certs(client_security))

        if mode is self.SecurityMode.MTLS:
            self.assertSecurityMtls(client_security, server_security)
        elif mode is self.SecurityMode.TLS:
            self.assertSecurityTls(client_security, server_security)
        elif mode is self.SecurityMode.PLAINTEXT:
            self.assertSecurityPlaintext(client_security, server_security)
        else:
            raise TypeError('Incorrect security mode')

    def assertSecurityMtls(self, client_security: grpc_channelz.Security,
                           server_security: grpc_channelz.Security):
        """Asserts both sockets use TLS and each side's certs match the peer's."""
        self.assertEqual(client_security.WhichOneof('model'),
                         'tls',
                         msg='(mTLS) Client socket security model must be TLS')
        self.assertEqual(server_security.WhichOneof('model'),
                         'tls',
                         msg='(mTLS) Server socket security model must be TLS')
        server_tls, client_tls = server_security.tls, client_security.tls

        # Confirm regular TLS: server local cert == client remote cert
        self.assertNotEmpty(client_tls.remote_certificate,
                            msg="(mTLS) Client remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(
                server_tls.local_certificate,
                msg="(mTLS) Server local certificate is missing")
            self.assertEqual(
                server_tls.local_certificate,
                client_tls.remote_certificate,
                msg="(mTLS) Server local certificate must match client's "
                "remote certificate")

        # mTLS: server remote cert == client local cert
        self.assertNotEmpty(server_tls.remote_certificate,
                            msg="(mTLS) Server remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(
                client_tls.local_certificate,
                msg="(mTLS) Client local certificate is missing")
            self.assertEqual(
                server_tls.remote_certificate,
                client_tls.local_certificate,
                msg="(mTLS) Server remote certificate must match client's "
                "local certificate")

    def assertSecurityTls(self, client_security: grpc_channelz.Security,
                          server_security: grpc_channelz.Security):
        """Asserts server-authenticated TLS without client certificates."""
        self.assertEqual(client_security.WhichOneof('model'),
                         'tls',
                         msg='(TLS) Client socket security model must be TLS')
        self.assertEqual(server_security.WhichOneof('model'),
                         'tls',
                         msg='(TLS) Server socket security model must be TLS')
        server_tls, client_tls = server_security.tls, client_security.tls

        # Regular TLS: server local cert == client remote cert
        self.assertNotEmpty(client_tls.remote_certificate,
                            msg="(TLS) Client remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(server_tls.local_certificate,
                                msg="(TLS) Server local certificate is missing")
            self.assertEqual(
                server_tls.local_certificate,
                client_tls.remote_certificate,
                msg="(TLS) Server local certificate must match client "
                "remote certificate")

        # mTLS must not be used
        self.assertEmpty(
            server_tls.remote_certificate,
            msg="(TLS) Server remote certificate must be empty in TLS mode. "
            "Is server security incorrectly configured for mTLS?")
        self.assertEmpty(
            client_tls.local_certificate,
            msg="(TLS) Client local certificate must be empty in TLS mode. "
            "Is client security incorrectly configured for mTLS?")

    def assertSecurityPlaintext(self, client_security, server_security):
        """Asserts that neither socket has any local or remote certificate."""
        server_tls, client_tls = server_security.tls, client_security.tls
        # Not TLS: the server presents no certificate, and the client
        # receives none.
        self.assertEmpty(
            server_tls.local_certificate,
            msg="(Plaintext) Server local certificate must be empty.")
        self.assertEmpty(
            client_tls.remote_certificate,
            msg="(Plaintext) Client remote certificate must be empty.")

        # Not mTLS: the client presents no certificate, and the server
        # receives none.
        self.assertEmpty(
            client_tls.local_certificate,
            msg="(Plaintext) Client local certificate must be empty.")
        self.assertEmpty(
            server_tls.remote_certificate,
            msg="(Plaintext) Server remote certificate must be empty.")

    def assertClientCannotReachServerRepeatedly(
            self,
            test_client: XdsTestClient,
            *,
            times: Optional[int] = None,
            delay: Optional[_timedelta] = None):
        """
        Asserts that the client repeatedly cannot reach the server.

        With negative tests we can't be absolutely certain expected failure
        state is not caused by something else.
        To mitigate for this, we repeat the checks several times, and expect
        all of them to succeed (i.e. every check confirms the server is
        unreachable).

        This is useful in case the channel eventually stabilizes, and RPCs pass.

        Args:
            test_client: An instance of XdsTestClient
            times: Optional; A positive number of times to confirm that
                the server is unreachable. Defaults to `3` attempts; values
                below 1 also fall back to `3`.
            delay: Optional; Specifies how long to wait before the next check.
                Defaults to `10` seconds.
        """
        if times is None or times < 1:
            times = 3
        if delay is None:
            delay = _timedelta(seconds=10)

        for i in range(1, times + 1):
            self.assertClientCannotReachServer(test_client)
            if i < times:
                logger.info('Check %s passed, waiting %s before the next check',
                            i, delay)
                time.sleep(delay.total_seconds())

    def assertClientCannotReachServer(self, test_client: XdsTestClient):
        """Asserts the client channel is failing and all RPCs fail."""
        self.assertClientChannelFailed(test_client)
        self.assertFailedRpcs(test_client)

    def assertClientChannelFailed(self, test_client: XdsTestClient):
        """Waits for TRANSIENT_FAILURE and asserts exactly one subchannel."""
        channel = test_client.wait_for_server_channel_state(
            state=_ChannelState.TRANSIENT_FAILURE)
        subchannels = list(
            test_client.channelz.list_channel_subchannels(channel))
        self.assertLen(subchannels,
                       1,
                       msg="Client channel must have exactly one subchannel "
                       "in state TRANSIENT_FAILURE.")

    @staticmethod
    def getConnectedSockets(
            test_client: XdsTestClient, test_server: XdsTestServer
    ) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
        """Returns the client's active socket and the matching server socket."""
        client_sock = test_client.get_active_server_channel_socket()
        server_sock = test_server.get_server_socket_matching_client(client_sock)
        return client_sock, server_sock

    @classmethod
    def debug_sock_certs(cls, security: grpc_channelz.Security):
        """Formats a socket's security info for logging."""
        if security.WhichOneof('model') == 'other':
            return f'other: <{security.other.name}={security.other.value}>'

        return (f'local: <{cls.debug_cert(security.tls.local_certificate)}>, '
                f'remote: <{cls.debug_cert(security.tls.remote_certificate)}>')

    @staticmethod
    def debug_cert(cert):
        """Returns a short fingerprint of ``cert`` for logs, or 'missing'."""
        if not cert:
            return 'missing'
        # SHA-1 is used only as a compact log fingerprint, not for security.
        sha1 = hashlib.sha1(cert)
        return f'sha1={sha1.hexdigest()}, len={len(cert)}'