# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import json
import logging
import subprocess
import time
from typing import Optional, List, Tuple

# TODO(sergiitk): replace with tenacity
import retrying
import kubernetes.config
from kubernetes import client
from kubernetes import utils
26
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type aliases: short names for the kubernetes client classes used below.
ApiException = client.ApiException
V1Deployment = client.V1Deployment
V1Namespace = client.V1Namespace
V1Pod = client.V1Pod
V1PodList = client.V1PodList
V1Service = client.V1Service
V1ServiceAccount = client.V1ServiceAccount
36
37
def simple_resource_get(func):
    """Decorate a resource getter so that an HTTP 404 yields None.

    The returned wrapper forwards all arguments to ``func``. When ``func``
    raises ``client.ApiException`` with ``status == 404`` the wrapper returns
    None; any other ``ApiException`` (and any other exception type) is
    propagated unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature as ``func``. The wrapper keeps
        the metadata of ``func`` (name, docstring, ``__wrapped__``).
    """

    # functools.wraps keeps the decorated function's name and docstring, so
    # logs, tracebacks and introspection show the real getter.
    @functools.wraps(func)
    def wrap_not_found_return_none(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except client.ApiException as e:
            if e.status == 404:
                # Ignore 404: the resource does not exist.
                return None
            raise

    return wrap_not_found_return_none
50
51
def label_dict_to_selector(labels: dict) -> str:
    """Render a label dict as a comma-separated ``key==value`` selector.

    Args:
        labels: Mapping of label names to label values.

    Returns:
        The pairs formatted as ``key==value`` and joined with commas, in the
        dict's iteration order. An empty dict gives an empty string.
    """
    pairs = []
    for key, value in labels.items():
        pairs.append(f'{key}=={value}')
    return ','.join(pairs)
54
55
class KubernetesApiManager:
    """Bundle the Kubernetes API objects for one kubeconfig context.

    Attributes:
        context: The kubeconfig context name this manager was created with.
        client: The ApiClient for that context. It comes from an lru_cache
            keyed on the context, so managers created with the same context
            share one client instance.
        apps: AppsV1Api bound to ``client``.
        core: CoreV1Api bound to ``client``.
    """

    def __init__(self, context):
        self.context = context
        api_client = self._cached_api_client_for_context(context)
        self.client = api_client
        self.apps = client.AppsV1Api(api_client)
        self.core = client.CoreV1Api(api_client)

    def close(self):
        """Close the underlying API client."""
        self.client.close()

    @classmethod
    @functools.lru_cache(maxsize=None)
    def _cached_api_client_for_context(cls, context: str) -> client.ApiClient:
        """Create (once per context) and return the ApiClient for ``context``.

        The result is memoized without a size bound, so the config is loaded
        and the log line is emitted only on the first call for each context.
        """
        api_client = kubernetes.config.new_client_from_config(context=context)
        logger.info('Using kubernetes context "%s", active host: %s', context,
                    api_client.configuration.host)
        return api_client
75
76
class PortForwardingError(Exception):
    """Raised when setting up port forwarding fails."""
79
80
class KubernetesNamespace:
    """Manage the resources of a single Kubernetes namespace.

    Every API call is made through the CoreV1/AppsV1 objects of the given
    KubernetesApiManager and is scoped to the namespace ``name``.

    The getters decorated with ``simple_resource_get`` return None instead of
    raising when the API answers with 404.

    The ``wait_for_*`` methods poll using the ``retrying`` library:
    ``timeout_sec`` is passed as ``stop_max_delay`` and ``wait_sec`` as
    ``wait_fixed`` (both converted to milliseconds). What is raised once the
    time limit is exceeded is decided by ``retrying`` (see its documentation).
    """
    # Service annotation whose value is a JSON document with the NEG status.
    NEG_STATUS_META = 'cloud.google.com/neg-status'
    # Default local address used by port_forward_pod().
    PORT_FORWARD_LOCAL_ADDRESS: str = '127.0.0.1'
    # Default grace period passed to the delete_* calls.
    DELETE_GRACE_PERIOD_SEC: int = 5
    # Default timeout / polling interval pairs for the wait_for_* methods.
    WAIT_SHORT_TIMEOUT_SEC: int = 60
    WAIT_SHORT_SLEEP_SEC: int = 1
    WAIT_MEDIUM_TIMEOUT_SEC: int = 5 * 60
    WAIT_MEDIUM_SLEEP_SEC: int = 10
    WAIT_LONG_TIMEOUT_SEC: int = 10 * 60
    WAIT_LONG_SLEEP_SEC: int = 30

    def __init__(self, api: KubernetesApiManager, name: str):
        """Store the API manager and the namespace name; makes no API call."""
        self.name = name
        self.api = api

    def apply_manifest(self, manifest):
        """Create the resource(s) described by ``manifest`` in this namespace.

        Args:
            manifest: The manifest as a dict, passed to
                ``kubernetes.utils.create_from_dict``.

        Returns:
            Whatever ``create_from_dict`` returns.
        """
        return utils.create_from_dict(self.api.client,
                                      manifest,
                                      namespace=self.name)

    @simple_resource_get
    def get_service(self, name) -> Optional[V1Service]:
        """Return the service ``name``, or None if the API responds 404."""
        return self.api.core.read_namespaced_service(name, self.name)

    @simple_resource_get
    def get_service_account(self, name) -> Optional[V1ServiceAccount]:
        """Return the service account ``name``, or None on a 404 response."""
        return self.api.core.read_namespaced_service_account(name, self.name)

    def delete_service(self,
                       name,
                       grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of the service ``name`` (Foreground propagation).

        This only issues the delete call; use wait_for_service_deleted() to
        wait until the service is gone.
        """
        self.api.core.delete_namespaced_service(
            name=name,
            namespace=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    def delete_service_account(self,
                               name,
                               grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of the service account ``name``.

        Uses Foreground propagation. This only issues the delete call; use
        wait_for_service_account_deleted() to wait until it is gone.
        """
        self.api.core.delete_namespaced_service_account(
            name=name,
            namespace=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    @simple_resource_get
    def get(self) -> Optional[V1Namespace]:
        """Return this namespace object, or None if the API responds 404."""
        return self.api.core.read_namespace(self.name)

    def delete(self, grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of this namespace (Foreground propagation).

        This only issues the delete call; use wait_for_namespace_deleted() to
        wait until the namespace is gone.
        """
        self.api.core.delete_namespace(
            name=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    def wait_for_service_deleted(self,
                                 name: str,
                                 timeout_sec=WAIT_SHORT_TIMEOUT_SEC,
                                 wait_sec=WAIT_SHORT_SLEEP_SEC):
        """Poll until get_service(name) returns None.

        Args:
            name: Service name.
            timeout_sec: Upper bound on the total polling time, in seconds.
            wait_sec: Delay between polls, in seconds.
        """

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_service_with_retry():
            service = self.get_service(name)
            if service is not None:
                logger.debug('Waiting for service %s to be deleted',
                             service.metadata.name)
            return service

        _wait_for_deleted_service_with_retry()

    def wait_for_service_account_deleted(self,
                                         name: str,
                                         timeout_sec=WAIT_SHORT_TIMEOUT_SEC,
                                         wait_sec=WAIT_SHORT_SLEEP_SEC):
        """Poll until get_service_account(name) returns None.

        Args:
            name: Service account name.
            timeout_sec: Upper bound on the total polling time, in seconds.
            wait_sec: Delay between polls, in seconds.
        """

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_service_account_with_retry():
            service_account = self.get_service_account(name)
            if service_account is not None:
                logger.debug('Waiting for service account %s to be deleted',
                             service_account.metadata.name)
            return service_account

        _wait_for_deleted_service_account_with_retry()

    def wait_for_namespace_deleted(self,
                                   timeout_sec=WAIT_LONG_TIMEOUT_SEC,
                                   wait_sec=WAIT_LONG_SLEEP_SEC):
        """Poll until get() returns None, i.e. this namespace is gone.

        Args:
            timeout_sec: Upper bound on the total polling time, in seconds.
            wait_sec: Delay between polls, in seconds.
        """

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_namespace_with_retry():
            namespace = self.get()
            if namespace is not None:
                logger.debug('Waiting for namespace %s to be deleted',
                             namespace.metadata.name)
            return namespace

        _wait_for_deleted_namespace_with_retry()

    def wait_for_service_neg(self,
                             name: str,
                             timeout_sec=WAIT_SHORT_TIMEOUT_SEC,
                             wait_sec=WAIT_SHORT_SLEEP_SEC):
        """Poll until the service carries the NEG status annotation.

        A service that does not exist (yet) or that has no annotations is
        treated as "not ready" and polled again.

        Args:
            name: Service name.
            timeout_sec: Upper bound on the total polling time, in seconds.
            wait_sec: Delay between polls, in seconds.
        """

        @retrying.retry(retry_on_result=lambda r: not r,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_service_neg():
            service = self.get_service(name)
            # get_service() returns None on 404, and the annotations mapping
            # may itself be None; neither case can contain the NEG status.
            annotations = None
            if service is not None:
                annotations = service.metadata.annotations
            if not annotations or self.NEG_STATUS_META not in annotations:
                logger.debug('Waiting for service %s NEG', name)
                return False
            return True

        _wait_for_service_neg()

    def get_service_neg(self, service_name: str,
                        service_port: int) -> Tuple[str, List[str]]:
        """Read the NEG name and zones from the service's NEG annotation.

        The annotation value is parsed as JSON. The NEG name is looked up
        under ``network_endpoint_groups`` by the port converted to a string,
        and the zones are taken from ``zones``.

        Args:
            service_name: Service name.
            service_port: Service port whose NEG name is wanted.

        Returns:
            A ``(neg_name, neg_zones)`` tuple.
        """
        service = self.get_service(service_name)
        neg_info: dict = json.loads(
            service.metadata.annotations[self.NEG_STATUS_META])
        neg_name: str = neg_info['network_endpoint_groups'][str(service_port)]
        neg_zones: List[str] = neg_info['zones']
        return neg_name, neg_zones

    @simple_resource_get
    def get_deployment(self, name) -> Optional[V1Deployment]:
        """Return the deployment ``name``, or None if the API responds 404."""
        return self.api.apps.read_namespaced_deployment(name, self.name)

    def delete_deployment(self,
                          name,
                          grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        """Request deletion of the deployment ``name``.

        Uses Foreground propagation. This only issues the delete call; use
        wait_for_deployment_deleted() to wait until it is gone.
        """
        self.api.apps.delete_namespaced_deployment(
            name=name,
            namespace=self.name,
            body=client.V1DeleteOptions(
                propagation_policy='Foreground',
                grace_period_seconds=grace_period_seconds))

    def list_deployment_pods(self, deployment: V1Deployment) -> List[V1Pod]:
        """List the pods matching the deployment selector's match_labels."""
        # V1LabelSelector.match_expressions not supported at the moment
        return self.list_pods_with_labels(deployment.spec.selector.match_labels)

    def wait_for_deployment_available_replicas(
            self,
            name,
            count=1,
            timeout_sec=WAIT_MEDIUM_TIMEOUT_SEC,
            wait_sec=WAIT_MEDIUM_SLEEP_SEC):
        """Poll until the deployment has at least ``count`` available replicas.

        A deployment that does not exist (yet) counts as "not available" and
        is polled again.

        Args:
            name: Deployment name.
            count: Minimum number of available replicas to wait for.
            timeout_sec: Upper bound on the total polling time, in seconds.
            wait_sec: Delay between polls, in seconds.
        """

        @retrying.retry(
            retry_on_result=lambda r: not self._replicas_available(r, count),
            stop_max_delay=timeout_sec * 1000,
            wait_fixed=wait_sec * 1000)
        def _wait_for_deployment_available_replicas():
            deployment = self.get_deployment(name)
            if deployment is None:
                # get_deployment() returns None on 404. Don't dereference it
                # here; _replicas_available() reports None as not available.
                logger.debug('Waiting for deployment %s to be created', name)
            else:
                logger.debug(
                    'Waiting for deployment %s to have %s available '
                    'replicas, current count %s', deployment.metadata.name,
                    count, deployment.status.available_replicas)
            return deployment

        _wait_for_deployment_available_replicas()

    def wait_for_deployment_deleted(self,
                                    deployment_name: str,
                                    timeout_sec=WAIT_MEDIUM_TIMEOUT_SEC,
                                    wait_sec=WAIT_MEDIUM_SLEEP_SEC):
        """Poll until get_deployment(deployment_name) returns None.

        Args:
            deployment_name: Deployment name.
            timeout_sec: Upper bound on the total polling time, in seconds.
            wait_sec: Delay between polls, in seconds.
        """

        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_deleted_deployment_with_retry():
            deployment = self.get_deployment(deployment_name)
            if deployment is not None:
                logger.debug(
                    'Waiting for deployment %s to be deleted. '
                    'Non-terminated replicas: %s', deployment.metadata.name,
                    deployment.status.replicas)
            return deployment

        _wait_for_deleted_deployment_with_retry()

    def list_pods_with_labels(self, labels: dict) -> List[V1Pod]:
        """List the pods in this namespace matching all ``labels``."""
        pod_list: V1PodList = self.api.core.list_namespaced_pod(
            self.name, label_selector=label_dict_to_selector(labels))
        return pod_list.items

    def get_pod(self, name) -> client.V1Pod:
        """Return the pod ``name``.

        Unlike the other getters this one is not wrapped with
        simple_resource_get, so API errors (including 404) propagate.
        """
        return self.api.core.read_namespaced_pod(name, self.name)

    def wait_for_pod_started(self,
                             pod_name,
                             timeout_sec=WAIT_SHORT_TIMEOUT_SEC,
                             wait_sec=WAIT_SHORT_SLEEP_SEC):
        """Poll until the pod's phase is neither 'Pending' nor 'Unknown'.

        Args:
            pod_name: Pod name.
            timeout_sec: Upper bound on the total polling time, in seconds.
            wait_sec: Delay between polls, in seconds.
        """

        @retrying.retry(retry_on_result=lambda r: not self._pod_started(r),
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
        def _wait_for_pod_started():
            pod = self.get_pod(pod_name)
            logger.debug('Waiting for pod %s to start, current phase: %s',
                         pod.metadata.name, pod.status.phase)
            return pod

        _wait_for_pod_started()

    def port_forward_pod(
        self,
        pod: V1Pod,
        remote_port: int,
        local_port: Optional[int] = None,
        local_address: Optional[str] = None,
    ) -> subprocess.Popen:
        """Experimental: forward a local port to a pod port using kubectl.

        Starts ``kubectl port-forward`` as a subprocess and waits for its
        first non-empty output line, which must be the expected
        "Forwarding from ..." message.

        Args:
            pod: The pod to forward to; only ``pod.metadata.name`` is used.
            remote_port: Port on the pod.
            local_port: Local port; falls back to ``remote_port`` when falsy.
            local_address: Local address to bind; falls back to
                PORT_FORWARD_LOCAL_ADDRESS when falsy.

        Returns:
            The running kubectl process. Stop it with port_forward_stop().

        Raises:
            PortForwardingError: kubectl exited before printing anything, or
                its first output line was not the expected one. The process
                is stopped before the error is re-raised.
        """
        local_address = local_address or self.PORT_FORWARD_LOCAL_ADDRESS
        local_port = local_port or remote_port
        cmd = [
            "kubectl", "--context", self.api.context, "--namespace", self.name,
            "port-forward", "--address", local_address,
            f"pod/{pod.metadata.name}", f"{local_port}:{remote_port}"
        ]
        # stderr is merged into stdout so error text shows up in the same
        # stream that is read below.
        pf = subprocess.Popen(cmd,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT,
                              universal_newlines=True)
        # Wait for stdout line indicating successful start.
        expected = (f"Forwarding from {local_address}:{local_port}"
                    f" -> {remote_port}")
        try:
            while True:
                time.sleep(0.05)
                output = pf.stdout.readline().strip()
                if not output:
                    # No output: fail only if kubectl has already exited,
                    # otherwise keep waiting.
                    return_code = pf.poll()
                    if return_code is not None:
                        errors = pf.stdout.readlines()
                        raise PortForwardingError(
                            'Error forwarding port, kubectl return '
                            f'code {return_code}, output {errors}')
                elif output != expected:
                    raise PortForwardingError(
                        f'Error forwarding port, unexpected output {output}')
                else:
                    logger.info(output)
                    break
        except Exception:
            # Don't leave the kubectl process running on failure.
            self.port_forward_stop(pf)
            raise

        # TODO(sergiitk): return new PortForwarder object
        return pf

    @staticmethod
    def port_forward_stop(pf):
        """Kill the port-forwarding process and collect its remaining output.

        Args:
            pf: The process returned by port_forward_pod().
        """
        logger.info('Shutting down port forwarding, pid %s', pf.pid)
        pf.kill()
        stdout, _stderr = pf.communicate(timeout=5)
        logger.info('Port forwarding stopped')
        logger.debug('Port forwarding remaining stdout: %s', stdout)

    @staticmethod
    def _pod_started(pod: V1Pod):
        """Return True unless the pod's phase is 'Pending' or 'Unknown'."""
        return pod.status.phase not in ('Pending', 'Unknown')

    @staticmethod
    def _replicas_available(deployment, count):
        """Return True if deployment has at least ``count`` available replicas.

        A None deployment or a None ``available_replicas`` gives False.
        """
        return (deployment is not None and
                deployment.status.available_replicas is not None and
                deployment.status.available_replicas >= count)
362