• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15from typing import Optional, Set
16
17from framework.infrastructure import gcp
18
logger = logging.getLogger(__name__)

# Type aliases for the GCP API wrappers used by the managers below. Names
# without a leading underscore appear in the managers' public annotations.

# Compute (v1).
_ComputeV1 = gcp.compute.ComputeV1
GcpResource = gcp.compute.ComputeV1.GcpResource
ZonalGcpResource = gcp.compute.ComputeV1.ZonalGcpResource
HealthCheckProtocol = gcp.compute.ComputeV1.HealthCheckProtocol
BackendServiceProtocol = gcp.compute.ComputeV1.BackendServiceProtocol
# Protocols used as defaults when callers pass nothing (or None).
_HealthCheckGRPC = gcp.compute.ComputeV1.HealthCheckProtocol.GRPC
_BackendGRPC = gcp.compute.ComputeV1.BackendServiceProtocol.GRPC

# Network Security (v1alpha1).
_NetworkSecurityV1Alpha1 = gcp.network_security.NetworkSecurityV1Alpha1
ServerTlsPolicy = gcp.network_security.NetworkSecurityV1Alpha1.ServerTlsPolicy
ClientTlsPolicy = gcp.network_security.NetworkSecurityV1Alpha1.ClientTlsPolicy

# Network Services (v1alpha1).
_NetworkServicesV1Alpha1 = gcp.network_services.NetworkServicesV1Alpha1
EndpointConfigSelector = (
    gcp.network_services.NetworkServicesV1Alpha1.EndpointConfigSelector)
39
40
class TrafficDirectorManager:
    """Creates and deletes the GCP resources of a Traffic Director setup.

    Managed resources, in creation order: health check, backend service,
    URL map, target proxy and forwarding rule. Each created resource is
    tracked on the instance so that :meth:`cleanup` can delete them in the
    reverse order.

    Every resource is named ``<resource_prefix>-<suffix>`` (see
    :meth:`_ns_name`). This lets ``force=True`` deletions target a resource
    by name even when this instance does not track it.
    """
    compute: _ComputeV1
    BACKEND_SERVICE_NAME = "backend-service"
    HEALTH_CHECK_NAME = "health-check"
    URL_MAP_NAME = "url-map"
    URL_MAP_PATH_MATCHER_NAME = "path-matcher"
    TARGET_PROXY_NAME = "target-proxy"
    FORWARDING_RULE_NAME = "forwarding-rule"

    def __init__(
        self,
        gcp_api_manager: gcp.api.GcpApiManager,
        project: str,
        *,
        resource_prefix: str,
        network: str = 'default',
    ):
        """Set up the Compute API wrapper and empty resource tracking state.

        Args:
          gcp_api_manager: passed to the Compute API wrapper.
          project: GCP project ID, passed to the Compute API wrapper.
          resource_prefix: prefix of every managed resource name.
          network: network name used when creating the forwarding rule.
        """
        # API
        self.compute = _ComputeV1(gcp_api_manager, project)

        # Settings
        self.project: str = project
        self.network: str = network
        self.resource_prefix: str = resource_prefix

        # Managed resources
        self.health_check: Optional[GcpResource] = None
        self.backend_service: Optional[GcpResource] = None
        # TODO(sergiitk): remove this flag once backend service resource loaded
        self.backend_service_protocol: Optional[BackendServiceProtocol] = None
        self.url_map: Optional[GcpResource] = None
        self.target_proxy: Optional[GcpResource] = None
        # TODO(sergiitk): remove this flag once target proxy resource loaded
        self.target_proxy_is_http: bool = False
        self.forwarding_rule: Optional[GcpResource] = None
        self.backends: Set[ZonalGcpResource] = set()

    @property
    def network_url(self):
        """Network reference in the ``global/networks/<network>`` form."""
        return f'global/networks/{self.network}'

    def setup_for_grpc(
            self,
            service_host,
            service_port,
            *,
            backend_protocol: Optional[BackendServiceProtocol] = _BackendGRPC,
            health_check_port: Optional[int] = None):
        """Create all resources for routing ``service_host:service_port``.

        Args:
          service_host: host part of the address matched by the URL map.
          service_port: port matched by the URL map and the forwarding rule.
          backend_protocol: backend service protocol; ``None`` means GRPC.
          health_check_port: health check port, passed through as-is.
        """
        self.setup_backend_for_grpc(protocol=backend_protocol,
                                    health_check_port=health_check_port)
        self.setup_routing_rule_map_for_grpc(service_host, service_port)

    def setup_backend_for_grpc(
            self,
            *,
            protocol: Optional[BackendServiceProtocol] = _BackendGRPC,
            health_check_port: Optional[int] = None):
        """Create the health check, then the backend service that uses it."""
        self.create_health_check(port=health_check_port)
        self.create_backend_service(protocol)

    def setup_routing_rule_map_for_grpc(self, service_host, service_port):
        """Create the URL map, target proxy and forwarding rule, in order."""
        self.create_url_map(service_host, service_port)
        self.create_target_proxy()
        self.create_forwarding_rule(service_port)

    def cleanup(self, *, force=False):
        """Delete all managed resources in the reverse order of creation.

        Args:
          force: delete each resource by its prefixed name even when this
            instance does not track it.
        """
        # Cleanup in the reverse order of creation
        self.delete_forwarding_rule(force=force)
        # The proxy kind is only known if this instance created the proxy;
        # otherwise the gRPC proxy is assumed (see TODO in __init__).
        if self.target_proxy_is_http:
            self.delete_target_http_proxy(force=force)
        else:
            self.delete_target_grpc_proxy(force=force)
        self.delete_url_map(force=force)
        self.delete_backend_service(force=force)
        self.delete_health_check(force=force)

    def _ns_name(self, name):
        """Return ``name`` prefixed with ``resource_prefix``."""
        return f'{self.resource_prefix}-{name}'

    def create_health_check(
            self,
            *,
            protocol: Optional[HealthCheckProtocol] = _HealthCheckGRPC,
            port: Optional[int] = None):
        """Create the health check and store it in ``self.health_check``.

        Args:
          protocol: health check protocol; ``None`` means GRPC.
          port: health check port, passed through to the Compute API.

        Raises:
          ValueError: if a health check is already tracked by this instance.
        """
        if self.health_check:
            raise ValueError(f'Health check {self.health_check.name} '
                             'already created, delete it first')
        if protocol is None:
            protocol = _HealthCheckGRPC

        name = self._ns_name(self.HEALTH_CHECK_NAME)
        logger.info('Creating %s Health Check "%s"', protocol.name, name)
        resource = self.compute.create_health_check(name, protocol, port=port)
        self.health_check = resource

    def delete_health_check(self, force=False):
        """Delete the health check; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.HEALTH_CHECK_NAME)
        elif self.health_check:
            name = self.health_check.name
        else:
            return
        logger.info('Deleting Health Check "%s"', name)
        self.compute.delete_health_check(name)
        self.health_check = None

    def create_backend_service(
            self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC):
        """Create the backend service using ``self.health_check``.

        Args:
          protocol: backend service protocol; ``None`` means GRPC. It is
            remembered in ``backend_service_protocol`` to pick the target
            proxy kind later.
        """
        if protocol is None:
            protocol = _BackendGRPC

        name = self._ns_name(self.BACKEND_SERVICE_NAME)
        logger.info('Creating %s Backend Service "%s"', protocol.name, name)
        resource = self.compute.create_backend_service_traffic_director(
            name, health_check=self.health_check, protocol=protocol)
        self.backend_service = resource
        self.backend_service_protocol = protocol

    def load_backend_service(self):
        """Load the existing backend service by its prefixed name.

        Note: ``backend_service_protocol`` is not set by this method.
        """
        name = self._ns_name(self.BACKEND_SERVICE_NAME)
        resource = self.compute.get_backend_service_traffic_director(name)
        self.backend_service = resource

    def delete_backend_service(self, force=False):
        """Delete the backend service; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.BACKEND_SERVICE_NAME)
        elif self.backend_service:
            name = self.backend_service.name
        else:
            return
        logger.info('Deleting Backend Service "%s"', name)
        self.compute.delete_backend_service(name)
        self.backend_service = None
        # Don't keep a stale protocol around for a backend service that no
        # longer exists (mirrors how target proxy deletion resets
        # target_proxy_is_http).
        self.backend_service_protocol = None

    def backend_service_add_neg_backends(self, name, zones):
        """Wait for NEG ``name`` in each zone and add it as a backend.

        Every loaded NEG is recorded in ``self.backends``, then all recorded
        backends are added to the backend service.
        """
        logger.info('Waiting for Network Endpoint Groups to load endpoints.')
        for zone in zones:
            backend = self.compute.wait_for_network_endpoint_group(name, zone)
            logger.info('Loaded NEG "%s" in zone %s', backend.name,
                        backend.zone)
            self.backends.add(backend)
        self.backend_service_add_backends()

    def backend_service_add_backends(self):
        """Add every backend in ``self.backends`` to the backend service."""
        logger.info('Adding backends to Backend Service %s: %r',
                    self.backend_service.name, self.backends)
        self.compute.backend_service_add_backends(self.backend_service,
                                                  self.backends)

    def backend_service_remove_all_backends(self):
        """Remove all backends from the backend service.

        ``self.backends`` itself is left unchanged.
        """
        logger.info('Removing backends from Backend Service %s',
                    self.backend_service.name)
        self.compute.backend_service_remove_all_backends(self.backend_service)

    def wait_for_backends_healthy_status(self):
        """Wait via the Compute API wrapper for ``self.backends`` to be healthy."""
        logger.debug(
            "Waiting for Backend Service %s to report all backends healthy %r",
            self.backend_service, self.backends)
        self.compute.wait_for_backends_healthy_status(self.backend_service,
                                                      self.backends)

    def create_url_map(
        self,
        src_host: str,
        src_port: int,
    ) -> GcpResource:
        """Create a URL map routing ``src_host:src_port`` to the backend service.

        Returns:
          The created URL map, also stored in ``self.url_map``.
        """
        src_address = f'{src_host}:{src_port}'
        name = self._ns_name(self.URL_MAP_NAME)
        matcher_name = self._ns_name(self.URL_MAP_PATH_MATCHER_NAME)
        logger.info('Creating URL map "%s": %s -> %s', name, src_address,
                    self.backend_service.name)
        resource = self.compute.create_url_map(name, matcher_name,
                                               [src_address],
                                               self.backend_service)
        self.url_map = resource
        return resource

    def delete_url_map(self, force=False):
        """Delete the URL map; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.URL_MAP_NAME)
        elif self.url_map:
            name = self.url_map.name
        else:
            return
        logger.info('Deleting URL Map "%s"', name)
        self.compute.delete_url_map(name)
        self.url_map = None

    def create_target_proxy(self):
        """Create a target proxy for ``self.url_map``.

        A GRPC backend service gets a target gRPC proxy and an HTTP2 one gets
        a target HTTP proxy. ``target_proxy_is_http`` records which kind was
        created so that :meth:`cleanup` deletes the right one.

        Raises:
          TypeError: if ``backend_service_protocol`` is neither GRPC nor
            HTTP2 (including when it is unset).
        """
        name = self._ns_name(self.TARGET_PROXY_NAME)
        if self.backend_service_protocol is BackendServiceProtocol.GRPC:
            target_proxy_type = 'GRPC'
            create_proxy_fn = self.compute.create_target_grpc_proxy
            self.target_proxy_is_http = False
        elif self.backend_service_protocol is BackendServiceProtocol.HTTP2:
            target_proxy_type = 'HTTP'
            create_proxy_fn = self.compute.create_target_http_proxy
            self.target_proxy_is_http = True
        else:
            raise TypeError('Unexpected backend service protocol: '
                            f'{self.backend_service_protocol}')

        # Arguments follow the placeholder order: proxy type, then its name.
        logger.info('Creating target %s proxy "%s" to URL map %s',
                    target_proxy_type, name, self.url_map.name)
        self.target_proxy = create_proxy_fn(name, self.url_map)

    def delete_target_grpc_proxy(self, force=False):
        """Delete the target gRPC proxy; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.TARGET_PROXY_NAME)
        elif self.target_proxy:
            name = self.target_proxy.name
        else:
            return
        logger.info('Deleting Target GRPC proxy "%s"', name)
        self.compute.delete_target_grpc_proxy(name)
        self.target_proxy = None
        self.target_proxy_is_http = False

    def delete_target_http_proxy(self, force=False):
        """Delete the target HTTP proxy; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.TARGET_PROXY_NAME)
        elif self.target_proxy:
            name = self.target_proxy.name
        else:
            return
        logger.info('Deleting HTTP Target proxy "%s"', name)
        self.compute.delete_target_http_proxy(name)
        self.target_proxy = None
        self.target_proxy_is_http = False

    def create_forwarding_rule(self, src_port: int):
        """Create a forwarding rule for ``src_port`` pointing at the proxy.

        Args:
          src_port: port number; anything ``int()`` accepts.

        Returns:
          The created forwarding rule, also stored in ``self.forwarding_rule``.
        """
        name = self._ns_name(self.FORWARDING_RULE_NAME)
        src_port = int(src_port)
        logger.info(
            'Creating forwarding rule "%s" in network "%s": 0.0.0.0:%s -> %s',
            name, self.network, src_port, self.target_proxy.url)
        resource = self.compute.create_forwarding_rule(name, src_port,
                                                       self.target_proxy,
                                                       self.network_url)
        self.forwarding_rule = resource
        return resource

    def delete_forwarding_rule(self, force=False):
        """Delete the forwarding rule; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.FORWARDING_RULE_NAME)
        elif self.forwarding_rule:
            name = self.forwarding_rule.name
        else:
            return
        logger.info('Deleting Forwarding rule "%s"', name)
        self.compute.delete_forwarding_rule(name)
        self.forwarding_rule = None
292
293
class TrafficDirectorSecureManager(TrafficDirectorManager):
    """TrafficDirectorManager that also manages TLS-related resources.

    On top of the base resources it manages the following:
      * a Server TLS Policy and an Endpoint Config Selector (server side);
      * a Client TLS Policy attached to the backend service (client side).

    Certificates and CAs in both policies reference the
    ``CERTIFICATE_PROVIDER_INSTANCE`` plugin instance.
    """
    netsec: Optional[_NetworkSecurityV1Alpha1]
    netsvc: Optional[_NetworkServicesV1Alpha1]
    SERVER_TLS_POLICY_NAME = "server-tls-policy"
    CLIENT_TLS_POLICY_NAME = "client-tls-policy"
    ENDPOINT_CONFIG_SELECTOR_NAME = "endpoint-config-selector"
    CERTIFICATE_PROVIDER_INSTANCE = "google_cloud_private_spiffe"

    def __init__(
        self,
        gcp_api_manager: gcp.api.GcpApiManager,
        project: str,
        *,
        resource_prefix: str,
        network: str = 'default',
    ):
        """Set up the base manager and the Network Security/Services APIs.

        Arguments are the same as for ``TrafficDirectorManager``.
        """
        super().__init__(gcp_api_manager,
                         project,
                         resource_prefix=resource_prefix,
                         network=network)

        # API
        self.netsec = _NetworkSecurityV1Alpha1(gcp_api_manager, project)
        self.netsvc = _NetworkServicesV1Alpha1(gcp_api_manager, project)

        # Managed resources
        self.server_tls_policy: Optional[ServerTlsPolicy] = None
        self.ecs: Optional[EndpointConfigSelector] = None
        self.client_tls_policy: Optional[ClientTlsPolicy] = None

    def setup_server_security(self,
                              *,
                              server_namespace,
                              server_name,
                              server_port,
                              tls=True,
                              mtls=True):
        """Create the Server TLS Policy and the Endpoint Config Selector.

        The selector matches endpoints labeled
        ``app=<server_namespace>-<server_name>`` on ``server_port``.
        """
        self.create_server_tls_policy(tls=tls, mtls=mtls)
        self.create_endpoint_config_selector(server_namespace=server_namespace,
                                             server_name=server_name,
                                             server_port=server_port)

    def setup_client_security(self,
                              *,
                              server_namespace,
                              server_name,
                              tls=True,
                              mtls=True):
        """Create the Client TLS Policy and attach it to the backend service."""
        self.create_client_tls_policy(tls=tls, mtls=mtls)
        self.backend_service_apply_client_mtls_policy(server_namespace,
                                                      server_name)

    def cleanup(self, *, force=False):
        """Delete all managed resources, base resources first.

        The backend service (deleted by the base class) is patched to
        reference the Client TLS Policy, and the Endpoint Config Selector
        references the Server TLS Policy. Each of them is removed before the
        policy it references.
        """
        # Cleanup in the reverse order of creation
        super().cleanup(force=force)
        self.delete_endpoint_config_selector(force=force)
        self.delete_server_tls_policy(force=force)
        self.delete_client_tls_policy(force=force)

    def create_server_tls_policy(self, *, tls, mtls):
        """Create the Server TLS Policy and load it into ``server_tls_policy``.

        Args:
          tls: include ``serverCertificate`` from the certificate provider.
          mtls: include an ``mtlsPolicy`` that validates client certificates
            with the certificate provider.

        When both flags are false, nothing is created.
        """
        name = self._ns_name(self.SERVER_TLS_POLICY_NAME)
        if not tls and not mtls:
            logger.warning(
                'Server TLS Policy %s neither TLS, nor mTLS '
                'policy. Skipping creation', name)
            return
        # Logged only once creation is certain, so that a skipped policy does
        # not log "Creating ..." followed by "Skipping creation".
        logger.info('Creating Server TLS Policy %s', name)

        certificate_provider = self._get_certificate_provider()
        policy = {}
        if tls:
            policy["serverCertificate"] = certificate_provider
        if mtls:
            policy["mtlsPolicy"] = {
                "clientValidationCa": [certificate_provider],
            }

        self.netsec.create_server_tls_policy(name, policy)
        self.server_tls_policy = self.netsec.get_server_tls_policy(name)
        logger.debug('Server TLS Policy loaded: %r', self.server_tls_policy)

    def delete_server_tls_policy(self, force=False):
        """Delete the Server TLS Policy; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.SERVER_TLS_POLICY_NAME)
        elif self.server_tls_policy:
            name = self.server_tls_policy.name
        else:
            return
        logger.info('Deleting Server TLS Policy %s', name)
        self.netsec.delete_server_tls_policy(name)
        self.server_tls_policy = None

    def create_endpoint_config_selector(self, server_namespace, server_name,
                                        server_port):
        """Create a GRPC_SERVER Endpoint Config Selector and load it.

        It matches endpoints labeled ``app=<server_namespace>-<server_name>``
        on ``server_port``. The selector references the Server TLS Policy if
        one was created; otherwise a warning is logged.
        """
        name = self._ns_name(self.ENDPOINT_CONFIG_SELECTOR_NAME)
        logger.info('Creating Endpoint Config Selector %s', name)
        endpoint_matcher_labels = [{
            "labelName": "app",
            "labelValue": f"{server_namespace}-{server_name}"
        }]
        port_selector = {"ports": [str(server_port)]}
        label_matcher_all = {
            "metadataLabelMatchCriteria": "MATCH_ALL",
            "metadataLabels": endpoint_matcher_labels
        }
        config = {
            "type": "GRPC_SERVER",
            "httpFilters": {},
            "trafficPortSelector": port_selector,
            "endpointMatcher": {
                "metadataLabelMatcher": label_matcher_all
            },
        }
        if self.server_tls_policy:
            config["serverTlsPolicy"] = self.server_tls_policy.name
        else:
            logger.warning(
                'Creating Endpoint Config Selector %s with '
                'no Server TLS policy attached', name)

        self.netsvc.create_endpoint_config_selector(name, config)
        self.ecs = self.netsvc.get_endpoint_config_selector(name)
        logger.debug('Loaded Endpoint Config Selector: %r', self.ecs)

    def delete_endpoint_config_selector(self, force=False):
        """Delete the Endpoint Config Selector; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.ENDPOINT_CONFIG_SELECTOR_NAME)
        elif self.ecs:
            name = self.ecs.name
        else:
            return
        logger.info('Deleting Endpoint Config Selector %s', name)
        self.netsvc.delete_endpoint_config_selector(name)
        self.ecs = None

    def create_client_tls_policy(self, *, tls, mtls):
        """Create the Client TLS Policy and load it into ``client_tls_policy``.

        Args:
          tls: include ``serverValidationCa`` from the certificate provider.
          mtls: include ``clientCertificate`` from the certificate provider.

        When both flags are false, nothing is created.
        """
        name = self._ns_name(self.CLIENT_TLS_POLICY_NAME)
        if not tls and not mtls:
            logger.warning(
                'Client TLS Policy %s neither TLS, nor mTLS '
                'policy. Skipping creation', name)
            return
        # Logged only once creation is certain (see create_server_tls_policy).
        logger.info('Creating Client TLS Policy %s', name)

        certificate_provider = self._get_certificate_provider()
        policy = {}
        if tls:
            policy["serverValidationCa"] = [certificate_provider]
        if mtls:
            policy["clientCertificate"] = certificate_provider

        self.netsec.create_client_tls_policy(name, policy)
        self.client_tls_policy = self.netsec.get_client_tls_policy(name)
        logger.debug('Client TLS Policy loaded: %r', self.client_tls_policy)

    def delete_client_tls_policy(self, force=False):
        """Delete the Client TLS Policy; no-op if untracked and not ``force``."""
        if force:
            name = self._ns_name(self.CLIENT_TLS_POLICY_NAME)
        elif self.client_tls_policy:
            name = self.client_tls_policy.name
        else:
            return
        logger.info('Deleting Client TLS Policy %s', name)
        self.netsec.delete_client_tls_policy(name)
        self.client_tls_policy = None

    def backend_service_apply_client_mtls_policy(
        self,
        server_namespace,
        server_name,
    ):
        """Patch the backend service with the Client TLS Policy.

        The expected server identity (subject alt name) is the SPIFFE ID
        ``spiffe://<project>.svc.id.goog/ns/<server_namespace>/sa/<server_name>``.
        If no Client TLS Policy was created, only a warning is logged.
        """
        if not self.client_tls_policy:
            logger.warning(
                'Client TLS policy not created, '
                'skipping attaching to Backend Service %s',
                self.backend_service.name)
            return

        server_spiffe = (f'spiffe://{self.project}.svc.id.goog/'
                         f'ns/{server_namespace}/sa/{server_name}')
        logger.info(
            'Adding Client TLS Policy to Backend Service %s: %s, '
            'server %s', self.backend_service.name, self.client_tls_policy.url,
            server_spiffe)

        self.compute.patch_backend_service(
            self.backend_service, {
                'securitySettings': {
                    'clientTlsPolicy': self.client_tls_policy.url,
                    'subjectAltNames': [server_spiffe]
                }
            })

    @classmethod
    def _get_certificate_provider(cls):
        """Return a certificate provider reference to the plugin instance."""
        return {
            "certificateProviderInstance": {
                "pluginInstance": cls.CERTIFICATE_PROVIDER_INSTANCE,
            },
        }
493