# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import signal
import threading
import time
import sys

from typing import DefaultDict, Dict, List, Mapping, MutableMapping, Set
import collections

from concurrent import futures

import grpc

from src.proto.grpc.testing import test_pb2
from src.proto.grpc.testing import test_pb2_grpc
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import empty_pb2

logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


class _StatsWatcher:
    """Collects per-peer counts for RPCs whose ids fall in ``[start, end)``.

    Channel threads report finished RPCs through ``on_rpc_complete``. A stats
    handler calls ``await_rpc_stats_response``, which blocks until every RPC
    in the window has been recorded or a timeout elapses.
    """
    _start: int
    _end: int
    _rpcs_needed: int
    _rpcs_by_peer: DefaultDict[str, int]
    _no_remote_peer: int
    _condition: threading.Condition

    def __init__(self, start: int, end: int):
        self._start = start
        self._end = end
        # RPCs in the window that have not been reported yet.
        self._rpcs_needed = end - start
        self._rpcs_by_peer = collections.defaultdict(int)
        self._condition = threading.Condition()
        # RPCs reported with an empty peer string (failed RPCs report "").
        self._no_remote_peer = 0

    def on_rpc_complete(self, request_id: int, peer: str) -> None:
        """Records statistics for a single RPC.

        Ids outside this watcher's window are ignored. An empty ``peer`` is
        counted as a failure rather than attributed to a backend.
        """
        if self._start <= request_id < self._end:
            with self._condition:
                if not peer:
                    self._no_remote_peer += 1
                else:
                    self._rpcs_by_peer[peer] += 1
                self._rpcs_needed -= 1
                self._condition.notify()

    def await_rpc_stats_response(self, timeout_sec: int
                                ) -> messages_pb2.LoadBalancerStatsResponse:
        """Blocks until a full response has been collected.

        Waits at most ``timeout_sec`` seconds. RPCs still outstanding when the
        wait ends are reported as failures, together with RPCs that had no
        peer.
        """
        with self._condition:
            self._condition.wait_for(lambda: not self._rpcs_needed,
                                     timeout=float(timeout_sec))
            response = messages_pb2.LoadBalancerStatsResponse()
            for peer, count in self._rpcs_by_peer.items():
                response.rpcs_by_peer[peer] = count
            response.num_failures = self._no_remote_peer + self._rpcs_needed
        return response


# Guards _global_rpc_id and _watchers, which are shared between the channel
# threads and the stats-service handler threads.
_global_lock = threading.Lock()
# Set on SIGINT to make every channel thread stop sending RPCs.
_stop_event = threading.Event()
# Id that will be assigned to the next RPC sent on any channel.
_global_rpc_id: int = 0
_watchers: Set[_StatsWatcher] = set()
_global_server = None


def _handle_sigint(sig, frame):
    """Stops the channel threads and the stats server on SIGINT."""
    _stop_event.set()
    _global_server.stop(None)


class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
                                ):
    """Reports the client's per-backend RPC distribution to the test driver."""

    def __init__(self):
        super().__init__()

    def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest,
                       context: grpc.ServicerContext
                      ) -> messages_pb2.LoadBalancerStatsResponse:
        """Reports which backends served the next ``request.num_rpcs`` RPCs.

        Registers a watcher for upcoming RPC ids and blocks for up to
        ``request.timeout_sec`` seconds while those RPCs complete.
        """
        logger.info("Received stats request.")
        with _global_lock:
            # NOTE(review): the next RPC sent receives id _global_rpc_id, so
            # starting at +1 skips one RPC. Kept to preserve existing behavior.
            start = _global_rpc_id + 1
            end = start + request.num_rpcs
            watcher = _StatsWatcher(start, end)
            _watchers.add(watcher)
        try:
            response = watcher.await_rpc_stats_response(request.timeout_sec)
        finally:
            # Always unregister so the watcher stops receiving callbacks even
            # if waiting failed.
            with _global_lock:
                _watchers.remove(watcher)
        logger.info("Returning stats response: {}".format(response))
        return response


def _start_rpc(request_id: int, stub: test_pb2_grpc.TestServiceStub,
               timeout: float,
               futures: MutableMapping[int, grpc.Future]) -> None:
    """Sends one UnaryCall asynchronously and stores its future by id."""
    logger.info(f"Sending request to backend: {request_id}")
    future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
                                   timeout=timeout)
    futures[request_id] = future


def _on_rpc_done(rpc_id: int, future: grpc.Future,
                 print_response: bool) -> None:
    """Logs the outcome of a finished RPC and reports it to every watcher.

    A failed RPC is reported with an empty hostname, which watchers count as
    a failure.
    """
    exception = future.exception()
    hostname = ""
    if exception is not None:
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error(f"RPC {rpc_id} timed out")
        else:
            logger.error(exception)
    else:
        response = future.result()
        logger.info(f"Got result {rpc_id}")
        hostname = response.hostname
    if print_response:
        if future.code() == grpc.StatusCode.OK:
            logger.info("Successful response.")
        else:
            # The future is also the call object; its str() carries status.
            logger.info(f"RPC failed: {future}")
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(rpc_id, hostname)


def _remove_completed_rpcs(futures: MutableMapping[int, grpc.Future],
                           print_response: bool) -> None:
    """Processes every finished RPC in ``futures`` and removes it."""
    logger.debug("Removing completed RPCs")
    # Snapshot the finished ids first so the dict is not mutated mid-iteration.
    done = [rpc_id for rpc_id, future in futures.items() if future.done()]
    for rpc_id in done:
        _on_rpc_done(rpc_id, futures.pop(rpc_id), print_response)


def _cancel_all_rpcs(futures: Mapping[int, grpc.Future]) -> None:
    """Cancels every RPC that is still tracked in ``futures``."""
    logger.info("Cancelling all remaining RPCs")
    for future in futures.values():
        future.cancel()


def _run_single_channel(args: argparse.Namespace):
    """Sends RPCs on one channel at ``args.qps`` until the stop event is set.

    Each iteration starts one RPC, reaps finished ones, and then waits out the
    rest of the per-query interval. On exit, in-flight RPCs are cancelled
    before the channel closes.
    """
    global _global_rpc_id  # pylint: disable=global-statement
    duration_per_query = 1.0 / float(args.qps)
    with grpc.insecure_channel(args.server) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)
        # Named so it does not shadow the concurrent.futures module import.
        rpc_futures: Dict[int, grpc.Future] = {}
        while not _stop_event.is_set():
            with _global_lock:
                request_id = _global_rpc_id
                _global_rpc_id += 1
            # Monotonic clock: wall-clock jumps must not distort the pacing.
            end = time.monotonic() + duration_per_query
            _start_rpc(request_id, stub, float(args.rpc_timeout_sec),
                       rpc_futures)
            _remove_completed_rpcs(rpc_futures, args.print_response)
            logger.debug(f"Currently {len(rpc_futures)} in-flight RPCs")
            remaining = end - time.monotonic()
            if remaining > 0:
                # Waiting on the event (instead of sleeping) lets SIGINT cut
                # the pacing delay short.
                _stop_event.wait(remaining)
        _cancel_all_rpcs(rpc_futures)


def _run(args: argparse.Namespace) -> None:
    """Starts the channel threads and serves stats until the server stops."""
    global _global_server  # pylint: disable=global-statement
    logger.info("Starting python xDS Interop Client.")
    channel_threads: List[threading.Thread] = []
    for _ in range(args.num_channels):
        thread = threading.Thread(target=_run_single_channel, args=(args,))
        thread.start()
        channel_threads.append(thread)
    _global_server = grpc.server(futures.ThreadPoolExecutor())
    _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        _LoadBalancerStatsServicer(), _global_server)
    _global_server.start()
    _global_server.wait_for_termination()
    # Wait for every channel thread to finish cancelling its RPCs.
    for thread in channel_threads:
        thread.join()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Run Python XDS interop client.')
    parser.add_argument(
        "--num_channels",
        default=1,
        type=int,
        help="The number of channels from which to send requests.")
    parser.add_argument("--print_response",
                        default=False,
                        action="store_true",
                        help="Write RPC response to STDOUT.")
    parser.add_argument(
        "--qps",
        default=1,
        type=int,
        help="The number of queries to send from each channel per second.")
    parser.add_argument("--rpc_timeout_sec",
                        default=30,
                        type=int,
                        help="The per-RPC timeout in seconds.")
    parser.add_argument("--server",
                        default="localhost:50051",
                        help="The address of the server.")
    parser.add_argument(
        "--stats_port",
        default=50052,
        type=int,
        help="The port on which to expose the peer distribution stats service.")
    parser.add_argument('--verbose',
                        help='verbose log output',
                        default=False,
                        action='store_true')
    parser.add_argument("--log_file",
                        default=None,
                        type=str,
                        help="A file to log to.")
    args = parser.parse_args()
    signal.signal(signal.SIGINT, _handle_sigint)
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.log_file:
        file_handler = logging.FileHandler(args.log_file, mode='a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    _run(args)