• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15xDS Test Client.
16
17TODO(sergiitk): separate XdsTestClient and KubernetesClientRunner to individual
18modules.
19"""
20import datetime
21import functools
22import logging
23from typing import Iterator, Optional
24
25from framework.helpers import retryers
26from framework.infrastructure import k8s
27import framework.rpc
28from framework.rpc import grpc_channelz
29from framework.rpc import grpc_testing
30from framework.test_app import base_runner
31
logger = logging.getLogger(__name__)

# Type aliases: short private names for framework types used in this module.
_timedelta = datetime.timedelta

# LoadBalancerStatsService client wrapper.
_LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient

# Channelz client wrapper and the channelz message/enum types it returns.
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelzChannel = grpc_channelz.Channel
_ChannelzChannelState = grpc_channelz.ChannelState
_ChannelzSubchannel = grpc_channelz.Subchannel
_ChannelzSocket = grpc_channelz.Socket
42
43
class XdsTestClient(framework.rpc.grpc.GrpcApp):
    """
    Represents RPC services implemented in Client component of the xds test app.
    https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client
    """

    def __init__(self,
                 *,
                 ip: str,
                 rpc_port: int,
                 server_target: str,
                 rpc_host: Optional[str] = None,
                 maintenance_port: Optional[int] = None):
        """Create a handle to a running xDS test client.

        Args:
            ip: IP address of the client. Also used as the host to connect to
                when rpc_host is not given.
            rpc_port: Port serving the LoadBalancerStatsService.
            server_target: Target the client sends RPCs to. Used to locate
                the client's channels to the server via channelz.
            rpc_host: Host to connect to instead of ip (for example, a local
                port-forwarding address).
            maintenance_port: Port serving channelz. Defaults to rpc_port.
        """
        super().__init__(rpc_host=(rpc_host or ip))
        self.ip = ip
        self.rpc_port = rpc_port
        self.server_target = server_target
        self.maintenance_port = maintenance_port or rpc_port

    # NOTE: lru_cache on a method keys on `self`, so each instance builds its
    # stats client (and channel) once, and the cache keeps the instance alive.
    @property
    @functools.lru_cache(None)
    def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient:
        """LoadBalancerStatsService client connected to rpc_port (cached)."""
        return _LoadBalancerStatsServiceClient(self._make_channel(
            self.rpc_port))

    @property
    @functools.lru_cache(None)
    def channelz(self) -> _ChannelzServiceClient:
        """Channelz client connected to maintenance_port (cached)."""
        return _ChannelzServiceClient(self._make_channel(self.maintenance_port))

    def get_load_balancer_stats(
        self,
        *,
        num_rpcs: int,
        timeout_sec: Optional[int] = None,
    ) -> grpc_testing.LoadBalancerStatsResponse:
        """
        Shortcut to LoadBalancerStatsServiceClient.get_client_stats()
        """
        return self.load_balancer_stats.get_client_stats(
            num_rpcs=num_rpcs, timeout_sec=timeout_sec)

    def wait_for_active_server_channel(self) -> _ChannelzChannel:
        """Wait for the channel to the server to transition to READY.

        Raises:
            GrpcApp.NotFound: If the channel never transitioned to READY.
                NOTE(review): on timeout the error surfaces through the
                retryer; confirm whether it re-raises NotFound or wraps it.
        """
        return self.wait_for_server_channel_state(_ChannelzChannelState.READY)

    def get_active_server_channel_socket(self) -> _ChannelzSocket:
        """Return the socket of the client's READY channel to the server.

        Uses the first subchannel of the READY channel and the first socket
        of that subchannel. Any additional subchannels or sockets are logged
        as warnings.

        Raises:
            GrpcApp.NotFound: If there is no READY channel to the server, the
                channel has no subchannels, or the subchannel has no sockets.
        """
        channel = self.find_server_channel_with_state(
            _ChannelzChannelState.READY)
        # Get the first subchannel of the active channel to the server.
        # Guard the index: a channel may report no subchannel refs.
        first_subchannel_name = (channel.subchannel_ref[0].name
                                 if channel.subchannel_ref else None)
        logger.debug(
            'Retrieving client -> server socket, '
            'channel_id: %s, subchannel: %s', channel.ref.channel_id,
            first_subchannel_name)

        subchannels = list(self.channelz.list_channel_subchannels(channel))
        if not subchannels:
            raise self.NotFound(
                f'Channel {channel.ref.channel_id} to {self.server_target} '
                'has no subchannels')
        subchannel, *extra_subchannels = subchannels
        if extra_subchannels:
            logger.warning('Unexpected subchannels: %r', extra_subchannels)

        # Get the first socket of the subchannel.
        sockets = list(self.channelz.list_subchannels_sockets(subchannel))
        if not sockets:
            raise self.NotFound(
                f'Subchannel {subchannel.ref.subchannel_id} of channel '
                f'{channel.ref.channel_id} has no sockets')
        socket, *extra_sockets = sockets
        if extra_sockets:
            # Log the extra sockets (previously the subchannels were logged
            # here by mistake).
            logger.warning('Unexpected sockets: %r', extra_sockets)
        logger.debug('Found client -> server socket: %s', socket.ref.name)
        return socket

    def wait_for_server_channel_state(
            self,
            state: _ChannelzChannelState,
            *,
            timeout: Optional[_timedelta] = None,
            rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel:
        """Poll channelz until the client has a channel to the server in state.

        Args:
            state: Channel connectivity state to wait for.
            timeout: Overall time budget for polling. Defaults to 5 minutes.
            rpc_deadline: Deadline for each channelz RPC made while polling.
                Defaults to 30 seconds.

        Returns:
            The first channel to server_target found in the requested state.
        """
        # When polling for a state, prefer smaller wait times to avoid
        # exhausting all allowed time on a single long RPC.
        if rpc_deadline is None:
            rpc_deadline = _timedelta(seconds=30)

        # Fine-tuned to wait for the channel to the server.
        retryer = retryers.exponential_retryer_with_timeout(
            wait_min=_timedelta(seconds=10),
            wait_max=_timedelta(seconds=25),
            timeout=_timedelta(minutes=5) if timeout is None else timeout)

        logger.info('Waiting for client %s to report a %s channel to %s',
                    self.ip, _ChannelzChannelState.Name(state),
                    self.server_target)
        channel = retryer(self.find_server_channel_with_state,
                          state,
                          rpc_deadline=rpc_deadline)
        logger.info('Client %s channel to %s transitioned to state %s:\n%s',
                    self.ip, self.server_target,
                    _ChannelzChannelState.Name(state), channel)
        return channel

    def find_server_channel_with_state(
            self,
            state: _ChannelzChannelState,
            *,
            rpc_deadline: Optional[_timedelta] = None,
            check_subchannel: bool = True) -> _ChannelzChannel:
        """Return the first channel to server_target in the given state.

        Args:
            state: Channel connectivity state to look for.
            rpc_deadline: Optional deadline applied to each channelz RPC.
            check_subchannel: When True, a channel matches only if it also
                has at least one subchannel in the same state.

        Raises:
            GrpcApp.NotFound: If no matching channel exists.
        """
        rpc_params = {}
        if rpc_deadline is not None:
            rpc_params['deadline_sec'] = rpc_deadline.total_seconds()

        for channel in self.get_server_channels(**rpc_params):
            channel_state: _ChannelzChannelState = channel.data.state.state
            logger.info('Server channel: %s, state: %s', channel.ref.name,
                        _ChannelzChannelState.Name(channel_state))
            # Enum values are ints: compare by value, not identity.
            if channel_state == state:
                if check_subchannel:
                    # When requested, check if the channel has at least
                    # one subchannel in the requested state.
                    try:
                        subchannel = self.find_subchannel_with_state(
                            channel, state, **rpc_params)
                        logger.info('Found subchannel in state %s: %s',
                                    _ChannelzChannelState.Name(state),
                                    subchannel)
                    except self.NotFound as e:
                        # Otherwise, keep searching.
                        logger.info(e.message)
                        continue
                return channel

        raise self.NotFound(
            f'Client has no {_ChannelzChannelState.Name(state)} channel with '
            'the server')

    def get_server_channels(self, **kwargs) -> Iterator[_ChannelzChannel]:
        """Return the client's channels to server_target, as found by channelz.

        Keyword arguments are forwarded to find_channels_for_target().
        """
        return self.channelz.find_channels_for_target(self.server_target,
                                                      **kwargs)

    def find_subchannel_with_state(self, channel: _ChannelzChannel,
                                   state: _ChannelzChannelState,
                                   **kwargs) -> _ChannelzSubchannel:
        """Return the first subchannel of channel that is in the given state.

        Keyword arguments are forwarded to list_channel_subchannels().

        Raises:
            GrpcApp.NotFound: If no subchannel of channel is in state.
        """
        subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
        for subchannel in subchannels:
            # Enum values are ints: compare by value, not identity.
            if subchannel.data.state.state == state:
                return subchannel

        raise self.NotFound(
            f'Not found a {_ChannelzChannelState.Name(state)} '
            f'subchannel for channel_id {channel.ref.channel_id}')
186                return subchannel
187
188        raise self.NotFound(
189            f'Not found a {_ChannelzChannelState.Name(state)} '
190            f'subchannel for channel_id {channel.ref.channel_id}')
191
192
193class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
194
195    def __init__(self,
196                 k8s_namespace,
197                 *,
198                 deployment_name,
199                 image_name,
200                 gcp_service_account,
201                 td_bootstrap_image,
202                 xds_server_uri=None,
203                 network='default',
204                 service_account_name=None,
205                 stats_port=8079,
206                 deployment_template='client.deployment.yaml',
207                 service_account_template='service-account.yaml',
208                 reuse_namespace=False,
209                 namespace_template=None,
210                 debug_use_port_forwarding=False):
211        super().__init__(k8s_namespace, namespace_template, reuse_namespace)
212
213        # Settings
214        self.deployment_name = deployment_name
215        self.image_name = image_name
216        self.gcp_service_account = gcp_service_account
217        self.service_account_name = service_account_name or deployment_name
218        self.stats_port = stats_port
219        # xDS bootstrap generator
220        self.td_bootstrap_image = td_bootstrap_image
221        self.xds_server_uri = xds_server_uri
222        self.network = network
223        self.deployment_template = deployment_template
224        self.service_account_template = service_account_template
225        self.debug_use_port_forwarding = debug_use_port_forwarding
226
227        # Mutable state
228        self.deployment: Optional[k8s.V1Deployment] = None
229        self.service_account: Optional[k8s.V1ServiceAccount] = None
230        self.port_forwarder = None
231
232    def run(self,
233            *,
234            server_target,
235            rpc='UnaryCall',
236            qps=25,
237            secure_mode=False,
238            print_response=False) -> XdsTestClient:
239        super().run()
240        # TODO(sergiitk): make rpc UnaryCall enum or get it from proto
241
242        # Create service account
243        self.service_account = self._create_service_account(
244            self.service_account_template,
245            service_account_name=self.service_account_name,
246            namespace_name=self.k8s_namespace.name,
247            gcp_service_account=self.gcp_service_account)
248
249        # Always create a new deployment
250        self.deployment = self._create_deployment(
251            self.deployment_template,
252            deployment_name=self.deployment_name,
253            image_name=self.image_name,
254            namespace_name=self.k8s_namespace.name,
255            service_account_name=self.service_account_name,
256            td_bootstrap_image=self.td_bootstrap_image,
257            xds_server_uri=self.xds_server_uri,
258            network=self.network,
259            stats_port=self.stats_port,
260            server_target=server_target,
261            rpc=rpc,
262            qps=qps,
263            secure_mode=secure_mode,
264            print_response=print_response)
265
266        self._wait_deployment_with_available_replicas(self.deployment_name)
267
268        # Load test client pod. We need only one client at the moment
269        pod = self.k8s_namespace.list_deployment_pods(self.deployment)[0]
270        self._wait_pod_started(pod.metadata.name)
271        pod_ip = pod.status.pod_ip
272        rpc_host = None
273
274        # Experimental, for local debugging.
275        if self.debug_use_port_forwarding:
276            logger.info('LOCAL DEV MODE: Enabling port forwarding to %s:%s',
277                        pod_ip, self.stats_port)
278            self.port_forwarder = self.k8s_namespace.port_forward_pod(
279                pod, remote_port=self.stats_port)
280            rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS
281
282        return XdsTestClient(ip=pod_ip,
283                             rpc_port=self.stats_port,
284                             server_target=server_target,
285                             rpc_host=rpc_host)
286
287    def cleanup(self, *, force=False, force_namespace=False):
288        if self.port_forwarder:
289            self.k8s_namespace.port_forward_stop(self.port_forwarder)
290            self.port_forwarder = None
291        if self.deployment or force:
292            self._delete_deployment(self.deployment_name)
293            self.deployment = None
294        if self.service_account or force:
295            self._delete_service_account(self.service_account_name)
296            self.service_account = None
297        super().cleanup(force=force_namespace and force)
298