# NOTE: removed code-browser navigation chrome ("Home / Line# / Scopes# /
# Navigate / Raw / Download") left over from the page this file was scraped from.
1# Copyright 2020 The gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse
16import collections
17import datetime
18import logging
19import signal
20import threading
21import time
22import sys
23
24from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple
25import collections
26
27from concurrent import futures
28
29import grpc
30
31from src.proto.grpc.testing import test_pb2
32from src.proto.grpc.testing import test_pb2_grpc
33from src.proto.grpc.testing import messages_pb2
34from src.proto.grpc.testing import empty_pb2
35
# Configure the root logger to emit timestamped, level-tagged records to
# stderr. Verbosity and an optional file handler are adjusted later from
# command-line flags (see the __main__ block at the bottom of this file).
logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
41
# The RPC methods this client knows how to drive against the test server.
_SUPPORTED_METHODS = (
    "UnaryCall",
    "EmptyCall",
)

# Maps CamelCase method names to the CAPS_SNAKE keys used in the
# LoadBalancerAccumulatedStatsResponse maps.
_METHOD_CAMEL_TO_CAPS_SNAKE = {
    "UnaryCall": "UNARY_CALL",
    "EmptyCall": "EMPTY_CALL",
}

# Maps CamelCase method names to their ClientConfigureRequest RpcType enum.
_METHOD_STR_TO_ENUM = {
    "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
    "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
}

# Inverse of _METHOD_STR_TO_ENUM: enum value -> CamelCase method name.
_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}

# Maps a method name to the (key, value) metadata pairs to attach to its RPCs.
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]

# How long an idle (qps == 0) channel thread waits before re-checking its
# configuration for changes.
_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
62
63
class _StatsWatcher:
    """Accumulates per-peer and per-method counts for a window of RPC ids.

    A watcher observes every RPC whose id falls in [start, end) and can be
    asked to block until all of them have completed or a timeout expires.
    """

    _start: int
    _end: int
    _rpcs_needed: int
    _rpcs_by_peer: DefaultDict[str, int]
    _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]]
    _no_remote_peer: int
    _lock: threading.Lock
    _condition: threading.Condition

    def __init__(self, start: int, end: int):
        self._start = start
        self._end = end
        self._rpcs_needed = end - start
        self._no_remote_peer = 0
        self._rpcs_by_peer = collections.defaultdict(int)
        self._rpcs_by_method = collections.defaultdict(
            lambda: collections.defaultdict(int))
        self._condition = threading.Condition()

    def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
        """Records statistics for a single RPC."""
        # RPCs outside this watcher's id window are ignored.
        if not self._start <= request_id < self._end:
            return
        with self._condition:
            if peer:
                self._rpcs_by_peer[peer] += 1
                self._rpcs_by_method[method][peer] += 1
            else:
                self._no_remote_peer += 1
            self._rpcs_needed -= 1
            self._condition.notify()

    def await_rpc_stats_response(
            self, timeout_sec: int) -> messages_pb2.LoadBalancerStatsResponse:
        """Blocks until a full response has been collected."""
        response = messages_pb2.LoadBalancerStatsResponse()
        with self._condition:
            self._condition.wait_for(lambda: not self._rpcs_needed,
                                     timeout=float(timeout_sec))
            for peer, count in self._rpcs_by_peer.items():
                response.rpcs_by_peer[peer] = count
            for method, counts_by_peer in self._rpcs_by_method.items():
                method_stats = response.rpcs_by_method[method]
                for peer, count in counts_by_peer.items():
                    method_stats.rpcs_by_peer[peer] = count
            # Anything still outstanding at timeout is reported as a failure.
            response.num_failures = self._no_remote_peer + self._rpcs_needed
        return response
110
111
# Guards the module-level counters and the watcher set below.
_global_lock = threading.Lock()
# Set from the SIGINT handler; channel threads exit their loops when set.
_stop_event = threading.Event()
# Monotonically increasing id handed to each RPC as it is started.
_global_rpc_id: int = 0
# Active _StatsWatcher instances to notify as RPCs complete.
_watchers: Set[_StatsWatcher] = set()
# The gRPC server exposing the stats/configure services; assigned in _run().
_global_server = None
# Per-method counts of RPCs started/succeeded/failed since process start.
_global_rpcs_started: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_succeeded: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_failed: Mapping[str, int] = collections.defaultdict(int)

# Mapping[method, Mapping[status_code, count]]
_global_rpc_statuses: Mapping[str, Mapping[int, int]] = collections.defaultdict(
    lambda: collections.defaultdict(int))
124
125
def _handle_sigint(sig, frame) -> None:
    """SIGINT handler: stops the channel threads and the stats server."""
    _stop_event.set()
    # A grace period of None cancels outstanding server RPCs immediately.
    _global_server.stop(None)
129
130
class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
                                ):
    """Serves snapshots of this client's load-balancing statistics."""

    def __init__(self):
        # Bug fix: this previously read super(_LoadBalancerStatsServicer),
        # which creates an unbound super object and never invokes the parent
        # class's __init__.
        super().__init__()

    def GetClientStats(
        self, request: messages_pb2.LoadBalancerStatsRequest,
        context: grpc.ServicerContext
    ) -> messages_pb2.LoadBalancerStatsResponse:
        """Watches the next `request.num_rpcs` RPCs and reports their
        per-peer/per-method distribution (or a partial count on timeout)."""
        logger.info("Received stats request.")
        start = None
        end = None
        watcher = None
        with _global_lock:
            # Watch the window of RPC ids starting just after the most
            # recently assigned one.
            start = _global_rpc_id + 1
            end = start + request.num_rpcs
            watcher = _StatsWatcher(start, end)
            _watchers.add(watcher)
        response = watcher.await_rpc_stats_response(request.timeout_sec)
        with _global_lock:
            _watchers.remove(watcher)
        logger.info("Returning stats response: %s", response)
        return response

    def GetClientAccumulatedStats(
        self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
        context: grpc.ServicerContext
    ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
        """Reports cumulative per-method RPC counts since process start."""
        logger.info("Received cumulative stats request.")
        response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
        with _global_lock:
            for method in _SUPPORTED_METHODS:
                caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
                response.num_rpcs_started_by_method[
                    caps_method] = _global_rpcs_started[method]
                response.num_rpcs_succeeded_by_method[
                    caps_method] = _global_rpcs_succeeded[method]
                response.num_rpcs_failed_by_method[
                    caps_method] = _global_rpcs_failed[method]
                response.stats_per_method[
                    caps_method].rpcs_started = _global_rpcs_started[method]
                for code, count in _global_rpc_statuses[method].items():
                    response.stats_per_method[caps_method].result[code] = count
        logger.info("Returning cumulative stats response.")
        return response
177
178
def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]],
               request_id: int, stub: test_pb2_grpc.TestServiceStub,
               timeout: float, futures: Mapping[int, Tuple[grpc.Future,
                                                           str]]) -> None:
    """Fires off a single asynchronous RPC and records its future.

    The future is stored in `futures` (mutated in place) under `request_id`,
    paired with the method name for later bookkeeping.
    """
    logger.info(f"Sending {method} request to backend: {request_id}")
    if method == "UnaryCall":
        call = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
                                     metadata=metadata,
                                     timeout=timeout)
    elif method == "EmptyCall":
        call = stub.EmptyCall.future(empty_pb2.Empty(),
                                     metadata=metadata,
                                     timeout=timeout)
    else:
        raise ValueError(f"Unrecognized method '{method}'.")
    futures[request_id] = (call, method)
195
196
def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str,
                 print_response: bool) -> None:
    """Tallies the outcome of a completed RPC and notifies stats watchers.

    Args:
        rpc_id: The globally unique id assigned when the RPC was started.
        future: The completed RPC future.
        method: The CamelCase method name, e.g. "UnaryCall".
        print_response: Whether to log each RPC's outcome.
    """
    exception = future.exception()
    hostname = ""
    _global_rpc_statuses[method][future.code().value[0]] += 1
    if exception is not None:
        with _global_lock:
            _global_rpcs_failed[method] += 1
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error(f"RPC {rpc_id} timed out")
        else:
            logger.error(exception)
    else:
        response = future.result()
        hostname = None
        # Prefer the hostname reported in initial metadata; fall back to the
        # hostname embedded in the response body.
        for metadatum in future.initial_metadata():
            if metadatum[0] == "hostname":
                hostname = metadatum[1]
                break
        else:
            hostname = response.hostname
        if future.code() == grpc.StatusCode.OK:
            with _global_lock:
                _global_rpcs_succeeded[method] += 1
        else:
            with _global_lock:
                _global_rpcs_failed[method] += 1
        if print_response:
            if future.code() == grpc.StatusCode.OK:
                logger.info("Successful response.")
            else:
                # Bug fix: the original referenced an undefined name `call`
                # here, raising NameError for any non-OK RPC that completed
                # without an exception while --print_response was set.
                logger.info(f"RPC {rpc_id} failed: {future.code()}")
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(rpc_id, hostname, method)
232
233
def _remove_completed_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]],
                           print_response: bool) -> None:
    """Reaps finished RPCs from the in-flight table.

    Records statistics for each completed RPC (via _on_rpc_done) and removes
    its entry from `futures`, which is mutated in place.
    """
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, (future, method) in futures.items():
        if future.done():
            # Bug fix: this previously passed the module-level
            # `args.print_response` (only defined when run as a script)
            # instead of the function's own `print_response` parameter.
            _on_rpc_done(future_id, future, method, print_response)
            done.append(future_id)
    for rpc_id in done:
        del futures[rpc_id]
244
245
def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
    """Best-effort cancellation of every in-flight RPC in `futures`."""
    logger.info("Cancelling all remaining RPCs")
    for call, _method in futures.values():
        call.cancel()
250
251
252class _ChannelConfiguration:
253    """Configuration for a single client channel.
254
255    Instances of this class are meant to be dealt with as PODs. That is,
256    data member should be accessed directly. This class is not thread-safe.
257    When accessing any of its members, the lock member should be held.
258    """
259
260    def __init__(self, method: str, metadata: Sequence[Tuple[str,
261                                                             str]], qps: int,
262                 server: str, rpc_timeout_sec: int, print_response: bool):
263        # condition is signalled when a change is made to the config.
264        self.condition = threading.Condition()
265
266        self.method = method
267        self.metadata = metadata
268        self.qps = qps
269        self.server = server
270        self.rpc_timeout_sec = rpc_timeout_sec
271        self.print_response = print_response
272
273
def _run_single_channel(config: _ChannelConfiguration) -> None:
    """Drives RPCs over one channel at `config.qps` until _stop_event is set.

    Runs in its own thread (see _MethodHandle). The configuration is re-read
    under its condition lock on every iteration so that Configure RPCs take
    effect on the fly; while qps is 0 the thread idles, waking on a config
    change notification or after _CONFIG_CHANGE_TIMEOUT to re-check.
    """
    global _global_rpc_id  # pylint: disable=global-statement
    with config.condition:
        server = config.server
    with grpc.insecure_channel(server) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)
        # In-flight RPCs: request id -> (future, method name).
        futures: Dict[int, Tuple[grpc.Future, str]] = {}
        while not _stop_event.is_set():
            with config.condition:
                if config.qps == 0:
                    # Paused: wait for a config change (or the timeout) and
                    # then re-check instead of busy-spinning.
                    config.condition.wait(
                        timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds())
                    continue
                else:
                    duration_per_query = 1.0 / float(config.qps)
            request_id = None
            with _global_lock:
                # Claim a globally unique id and count the RPC as started.
                request_id = _global_rpc_id
                _global_rpc_id += 1
                _global_rpcs_started[config.method] += 1
            start = time.time()
            end = start + duration_per_query
            with config.condition:
                _start_rpc(config.method, config.metadata, request_id, stub,
                           float(config.rpc_timeout_sec), futures)
            with config.condition:
                _remove_completed_rpcs(futures, config.print_response)
            logger.debug(f"Currently {len(futures)} in-flight RPCs")
            # Sleep out the remainder of this query's time slice to hold the
            # configured QPS.
            now = time.time()
            while now < end:
                time.sleep(end - now)
                now = time.time()
        _cancel_all_rpcs(futures)
307
308
class _XdsUpdateClientConfigureServicer(
        test_pb2_grpc.XdsUpdateClientConfigureServiceServicer):
    """Applies client configuration changes received over RPC.

    Enables the requested methods at the configured QPS (disabling every
    other supported method) and updates per-method metadata and timeouts.
    """

    def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration],
                 qps: int):
        # Bug fix: this previously read
        # super(_XdsUpdateClientConfigureServicer), which creates an unbound
        # super object and never invokes the parent class's __init__.
        super().__init__()
        self._per_method_configs = per_method_configs
        self._qps = qps

    def Configure(
            self, request: messages_pb2.ClientConfigureRequest,
            context: grpc.ServicerContext
    ) -> messages_pb2.ClientConfigureResponse:
        """Reconfigures each supported method's channel according to the
        requested set of active RPC types."""
        logger.info("Received Configure RPC: %s", request)
        # Bug fix: this was previously a generator expression, which `in`
        # membership tests consume. After checking the first supported
        # method, the generator was (partially) exhausted, so subsequent
        # methods could be mis-detected as inactive. A set supports repeated
        # membership tests.
        method_strs = {_METHOD_ENUM_TO_STR[t] for t in request.types}
        for method in _SUPPORTED_METHODS:
            method_enum = _METHOD_STR_TO_ENUM[method]
            channel_config = self._per_method_configs[method]
            if method in method_strs:
                qps = self._qps
                metadata = ((md.key, md.value)
                            for md in request.metadata
                            if md.type == method_enum)
                # For backward compatibility, do not change timeout when we
                # receive a default value timeout.
                if request.timeout_sec == 0:
                    timeout_sec = channel_config.rpc_timeout_sec
                else:
                    timeout_sec = request.timeout_sec
            else:
                qps = 0
                metadata = ()
                # Leave timeout unchanged for backward compatibility.
                timeout_sec = channel_config.rpc_timeout_sec
            with channel_config.condition:
                channel_config.qps = qps
                channel_config.metadata = list(metadata)
                channel_config.rpc_timeout_sec = timeout_sec
                channel_config.condition.notify_all()
        return messages_pb2.ClientConfigureResponse()
349
350
class _MethodHandle:
    """An object grouping together threads driving RPCs for a method."""

    _channel_threads: List[threading.Thread]

    def __init__(self, num_channels: int,
                 channel_config: _ChannelConfiguration):
        """Creates and starts a group of threads running the indicated method."""
        self._channel_threads = [
            threading.Thread(target=_run_single_channel,
                             args=(channel_config,))
            for _ in range(num_channels)
        ]
        for channel_thread in self._channel_threads:
            channel_thread.start()

    def stop(self) -> None:
        """Joins all threads referenced by the handle."""
        for channel_thread in self._channel_threads:
            channel_thread.join()
370
371
def _run(args: argparse.Namespace, methods: Sequence[str],
         per_method_metadata: PerMethodMetadataType) -> None:
    """Starts the channel worker threads and serves stats until terminated.

    Blocks until the stats server terminates (e.g. on SIGINT), then joins
    every worker thread.
    """
    logger.info("Starting python xDS Interop Client.")
    global _global_server  # pylint: disable=global-statement
    method_handles = []
    channel_configs = {}
    for method in _SUPPORTED_METHODS:
        # Methods not requested via --rpc are configured at zero QPS.
        qps = args.qps if method in methods else 0
        config = _ChannelConfiguration(method,
                                       per_method_metadata.get(method, []),
                                       qps, args.server, args.rpc_timeout_sec,
                                       args.print_response)
        channel_configs[method] = config
        method_handles.append(_MethodHandle(args.num_channels, config))
    _global_server = grpc.server(futures.ThreadPoolExecutor())
    _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        _LoadBalancerStatsServicer(), _global_server)
    test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
        _XdsUpdateClientConfigureServicer(channel_configs, args.qps),
        _global_server)
    _global_server.start()
    _global_server.wait_for_termination()
    for method_handle in method_handles:
        method_handle.stop()
399
400
def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType:
    """Parses the --metadata flag into per-method metadata pairs.

    Args:
        metadata_arg: A comma-delimited list of METHOD:KEY:VALUE triples,
            possibly empty.

    Returns:
        A mapping from method name to its list of (key, value) pairs.

    Raises:
        ValueError: If an entry is malformed or names an unsupported method.
    """
    # Bug fix: this previously tested the module-level `args.metadata`
    # instead of the function's own parameter, coupling the function to a
    # global that only exists when the file is run as a script.
    metadata = metadata_arg.split(",") if metadata_arg else []
    per_method_metadata = collections.defaultdict(list)
    for metadatum in metadata:
        elems = metadatum.split(":")
        if len(elems) != 3:
            raise ValueError(
                f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
        if elems[0] not in _SUPPORTED_METHODS:
            raise ValueError(f"Unrecognized method '{elems[0]}'")
        per_method_metadata[elems[0]].append((elems[1], elems[2]))
    return per_method_metadata
413
414
def parse_rpc_arg(rpc_arg: str) -> Sequence[str]:
    """Splits the --rpc flag into method names, rejecting unsupported ones."""
    methods = rpc_arg.split(",")
    unsupported = set(methods) - set(_SUPPORTED_METHODS)
    if unsupported:
        raise ValueError("--rpc supported methods: {}".format(
            ", ".join(_SUPPORTED_METHODS)))
    return methods
421
422
if __name__ == "__main__":
    # Command-line interface for the interop client.
    parser = argparse.ArgumentParser(
        description='Run Python XDS interop client.')
    parser.add_argument(
        "--num_channels",
        default=1,
        type=int,
        help="The number of channels from which to send requests.")
    parser.add_argument("--print_response",
                        default=False,
                        action="store_true",
                        help="Write RPC response to STDOUT.")
    parser.add_argument(
        "--qps",
        default=1,
        type=int,
        help="The number of queries to send from each channel per second.")
    parser.add_argument("--rpc_timeout_sec",
                        default=30,
                        type=int,
                        help="The per-RPC timeout in seconds.")
    parser.add_argument("--server",
                        default="localhost:50051",
                        help="The address of the server.")
    parser.add_argument(
        "--stats_port",
        default=50052,
        type=int,
        help="The port on which to expose the peer distribution stats service.")
    parser.add_argument('--verbose',
                        help='verbose log output',
                        default=False,
                        action='store_true')
    parser.add_argument("--log_file",
                        default=None,
                        type=str,
                        help="A file to log to.")
    rpc_help = "A comma-delimited list of RPC methods to run. Must be one of "
    rpc_help += ", ".join(_SUPPORTED_METHODS)
    rpc_help += "."
    parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help)
    metadata_help = (
        "A comma-delimited list of 3-tuples of the form " +
        "METHOD:KEY:VALUE, e.g. " +
        "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3")
    parser.add_argument("--metadata", default="", type=str, help=metadata_help)
    args = parser.parse_args()
    # Install the SIGINT handler that shuts down the server and workers.
    signal.signal(signal.SIGINT, _handle_sigint)
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.log_file:
        # Mirror log output to the requested file in addition to stderr.
        file_handler = logging.FileHandler(args.log_file, mode='a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata))
478