# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15import argparse
16import logging
17import signal
18import threading
19import time
20import sys
21
22from typing import DefaultDict, Dict, List, Mapping, Set
23import collections
24
25from concurrent import futures
26
27import grpc
28
29from src.proto.grpc.testing import test_pb2
30from src.proto.grpc.testing import test_pb2_grpc
31from src.proto.grpc.testing import messages_pb2
32from src.proto.grpc.testing import empty_pb2
33
34logger = logging.getLogger()
35console_handler = logging.StreamHandler()
36formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
37console_handler.setFormatter(formatter)
38logger.addHandler(console_handler)
39
40
class _StatsWatcher:
    """Collects per-peer completion counts for one stats request.

    Tracks RPCs whose ids fall in the half-open range [start, end) and
    wakes the waiting thread once every RPC in the range has finished.
    """
    _start: int
    _end: int
    _rpcs_needed: int
    _rpcs_by_peer: DefaultDict[str, int]
    _no_remote_peer: int
    _lock: threading.Lock
    _condition: threading.Condition

    def __init__(self, start: int, end: int):
        self._start = start
        self._end = end
        self._rpcs_needed = end - start
        self._rpcs_by_peer = collections.defaultdict(int)
        self._condition = threading.Condition()
        self._no_remote_peer = 0

    def on_rpc_complete(self, request_id: int, peer: str) -> None:
        """Records statistics for a single RPC."""
        # RPCs outside the watched id range are ignored entirely.
        if not self._start <= request_id < self._end:
            return
        with self._condition:
            if peer:
                self._rpcs_by_peer[peer] += 1
            else:
                # An empty peer means the RPC completed without a hostname.
                self._no_remote_peer += 1
            self._rpcs_needed -= 1
            self._condition.notify()

    def await_rpc_stats_response(self, timeout_sec: int
                                ) -> messages_pb2.LoadBalancerStatsResponse:
        """Blocks until a full response has been collected."""
        response = messages_pb2.LoadBalancerStatsResponse()
        with self._condition:
            # Wake when all watched RPCs are done, or give up at the timeout;
            # any still-outstanding RPCs are then reported as failures.
            self._condition.wait_for(lambda: not self._rpcs_needed,
                                     timeout=float(timeout_sec))
            for peer, count in self._rpcs_by_peer.items():
                response.rpcs_by_peer[peer] = count
            response.num_failures = self._no_remote_peer + self._rpcs_needed
        return response
80
81
# Guards _global_rpc_id and _watchers (and the read in GetClientStats).
_global_lock = threading.Lock()
# Set by the SIGINT handler; channel worker loops poll it to exit.
_stop_event = threading.Event()
# Monotonically increasing id handed to each outgoing RPC (under _global_lock).
_global_rpc_id: int = 0
# Active watchers notified on every RPC completion (under _global_lock).
_watchers: Set[_StatsWatcher] = set()
# The stats gRPC server; assigned in _run, stopped by the SIGINT handler.
_global_server = None
87
88
def _handle_sigint(sig, frame):
    """SIGINT handler: stops the worker threads and the stats server.

    NOTE(review): assumes _run has already assigned _global_server; a SIGINT
    delivered before that would raise AttributeError on None — confirm this
    window is acceptable.
    """
    _stop_event.set()
    _global_server.stop(None)
92
93
class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
                                ):
    """Serves the peer-distribution stats queried by the xDS test driver."""

    def __init__(self):
        # Bug fix: the original called super(_LoadBalancerStatsServicer).__init__(),
        # which initializes an unbound super proxy instead of the parent class.
        super().__init__()

    def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest,
                       context: grpc.ServicerContext
                      ) -> messages_pb2.LoadBalancerStatsResponse:
        """Registers a watcher for the next request.num_rpcs RPCs and blocks
        (up to request.timeout_sec) until their per-peer stats are collected.
        """
        logger.info("Received stats request.")
        with _global_lock:
            # Bug fix (off-by-one): _run_single_channel assigns
            # request_id = _global_rpc_id BEFORE incrementing, so the next RPC
            # to start carries id _global_rpc_id. The original started the
            # watched range at _global_rpc_id + 1, skipping that RPC.
            start = _global_rpc_id
            end = start + request.num_rpcs
            watcher = _StatsWatcher(start, end)
            _watchers.add(watcher)
        response = watcher.await_rpc_stats_response(request.timeout_sec)
        with _global_lock:
            _watchers.remove(watcher)
        logger.info("Returning stats response: {}".format(response))
        return response
117
118
def _start_rpc(request_id: int, stub: test_pb2_grpc.TestServiceStub,
               timeout: float, futures: Mapping[int, grpc.Future]) -> None:
    """Fires off one asynchronous UnaryCall and records its future under
    request_id in the supplied futures mapping."""
    logger.info(f"Sending request to backend: {request_id}")
    futures[request_id] = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
                                                timeout=timeout)
125
126
def _on_rpc_done(rpc_id: int, future: grpc.Future,
                 print_response: bool) -> None:
    """Reports the outcome of one completed RPC to every active watcher.

    Args:
        rpc_id: The id assigned to the RPC when it was started.
        future: The completed (done) RPC future.
        print_response: Whether to log the outcome of a successful RPC.
    """
    exception = future.exception()
    # Empty hostname signals "no remote peer" to the watchers.
    hostname = ""
    if exception is not None:
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error(f"RPC {rpc_id} timed out")
        else:
            logger.error(exception)
    else:
        response = future.result()
        logger.info(f"Got result {rpc_id}")
        hostname = response.hostname
        if print_response:
            if future.code() == grpc.StatusCode.OK:
                logger.info("Successful response.")
            else:
                # Bug fix: the original referenced an undefined name `call`
                # here, raising NameError instead of logging the failure.
                logger.info(f"RPC failed: {rpc_id}")
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(rpc_id, hostname)
148
149
def _remove_completed_rpcs(futures: Mapping[int, grpc.Future],
                           print_response: bool) -> None:
    """Reports and drops every future in `futures` that has completed.

    Args:
        futures: Mutable mapping of rpc_id -> in-flight future; completed
            entries are deleted in place.
        print_response: Forwarded to _on_rpc_done for each completed RPC.
    """
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, future in futures.items():
        if future.done():
            # Bug fix: the original read the module-level `args` (which only
            # exists when run under __main__) instead of the parameter.
            _on_rpc_done(future_id, future, print_response)
            done.append(future_id)
    # Delete after iteration: mutating the dict inside the loop above would
    # invalidate the iterator.
    for rpc_id in done:
        del futures[rpc_id]
160
161
def _cancel_all_rpcs(futures: Mapping[int, grpc.Future]) -> None:
    """Best-effort cancellation of every still-outstanding RPC."""
    logger.info("Cancelling all remaining RPCs")
    for rpc_future in futures.values():
        rpc_future.cancel()
166
167
def _run_single_channel(args: argparse.Namespace):
    """Drives one channel at args.qps until the global stop event is set.

    Each iteration claims a fresh global RPC id, starts one asynchronous
    UnaryCall, reaps completed calls, then sleeps out the remainder of the
    per-query period. Outstanding RPCs are cancelled on shutdown.
    """
    global _global_rpc_id  # pylint: disable=global-statement
    period = 1.0 / float(args.qps)
    with grpc.insecure_channel(args.server) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)
        in_flight: Dict[int, grpc.Future] = {}
        while not _stop_event.is_set():
            # Claim the next RPC id under the lock shared with GetClientStats.
            with _global_lock:
                request_id = _global_rpc_id
                _global_rpc_id += 1
            deadline = time.time() + period
            _start_rpc(request_id, stub, float(args.rpc_timeout_sec), in_flight)
            _remove_completed_rpcs(in_flight, args.print_response)
            logger.debug(f"Currently {len(in_flight)} in-flight RPCs")
            # Pace the loop: sleep until the start of the next query period.
            remaining = deadline - time.time()
            while remaining > 0:
                time.sleep(remaining)
                remaining = deadline - time.time()
        _cancel_all_rpcs(in_flight)
189
190
def _run(args: argparse.Namespace) -> None:
    """Starts the channel worker threads and serves the stats service.

    Blocks until the stats server terminates (e.g. stopped by the SIGINT
    handler), then joins every worker thread.
    """
    logger.info("Starting python xDS Interop Client.")
    global _global_server  # pylint: disable=global-statement
    channel_threads: List[threading.Thread] = []
    for _ in range(args.num_channels):
        thread = threading.Thread(target=_run_single_channel, args=(args,))
        thread.start()
        channel_threads.append(thread)
    _global_server = grpc.server(futures.ThreadPoolExecutor())
    _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        _LoadBalancerStatsServicer(), _global_server)
    _global_server.start()
    _global_server.wait_for_termination()
    # Bug fix: the original joined the loop-leftover `thread` N times, leaving
    # every worker thread but the last one unjoined.
    for channel_thread in channel_threads:
        channel_thread.join()
207
208
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Run Python XDS interop client.')
    parser.add_argument(
        "--num_channels",
        default=1,
        type=int,
        help="The number of channels from which to send requests.")
    parser.add_argument("--print_response",
                        default=False,
                        action="store_true",
                        help="Write RPC response to STDOUT.")
    parser.add_argument(
        "--qps",
        default=1,
        type=int,
        help="The number of queries to send from each channel per second.")
    parser.add_argument("--rpc_timeout_sec",
                        default=30,
                        type=int,
                        help="The per-RPC timeout in seconds.")
    parser.add_argument("--server",
                        default="localhost:50051",
                        help="The address of the server.")
    parser.add_argument(
        "--stats_port",
        default=50052,
        type=int,
        help="The port on which to expose the peer distribution stats service.")
    parser.add_argument('--verbose',
                        help='verbose log output',
                        default=False,
                        action='store_true')
    parser.add_argument("--log_file",
                        default=None,
                        type=str,
                        help="A file to log to.")
    args = parser.parse_args()
    # Install the SIGINT handler before any threads or the server start so
    # Ctrl-C shuts the client down cleanly.
    signal.signal(signal.SIGINT, _handle_sigint)
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.log_file:
        # Mirror console logging to the requested file, reusing the shared
        # module-level formatter.
        file_handler = logging.FileHandler(args.log_file, mode='a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    _run(args)