# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import collections
import datetime
import logging
import signal
import threading
import time
import sys

from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple

from concurrent import futures

import grpc

from src.proto.grpc.testing import test_pb2
from src.proto.grpc.testing import test_pb2_grpc
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import empty_pb2

logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

_SUPPORTED_METHODS = (
    "UnaryCall",
    "EmptyCall",
)

_METHOD_CAMEL_TO_CAPS_SNAKE = {
    "UnaryCall": "UNARY_CALL",
    "EmptyCall": "EMPTY_CALL",
}

_METHOD_STR_TO_ENUM = {
    "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
    "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
}

_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}

PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]

_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)


class _StatsWatcher:
    _start: int
    _end: int
    _rpcs_needed: int
    _rpcs_by_peer: DefaultDict[str, int]
    _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]]
    _no_remote_peer: int
    _lock: threading.Lock
    _condition: threading.Condition

    def __init__(self, start: int, end: int):
        self._start = start
        self._end = end
        self._rpcs_needed = end - start
        self._rpcs_by_peer = collections.defaultdict(int)
        self._rpcs_by_method = collections.defaultdict(
            lambda: collections.defaultdict(int))
        self._condition = threading.Condition()
        self._no_remote_peer = 0

    def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
        """Records statistics for a single RPC."""
        if self._start <= request_id < self._end:
            with self._condition:
                if not peer:
                    self._no_remote_peer += 1
                else:
                    self._rpcs_by_peer[peer] += 1
                    self._rpcs_by_method[method][peer] += 1
                self._rpcs_needed -= 1
                self._condition.notify()

    def await_rpc_stats_response(
            self, timeout_sec: int) -> messages_pb2.LoadBalancerStatsResponse:
        """Blocks until a full response has been collected."""
        with self._condition:
            self._condition.wait_for(lambda: not self._rpcs_needed,
                                     timeout=float(timeout_sec))
            response = messages_pb2.LoadBalancerStatsResponse()
            for peer, count in self._rpcs_by_peer.items():
                response.rpcs_by_peer[peer] = count
            for method, count_by_peer in self._rpcs_by_method.items():
                for peer, count in count_by_peer.items():
                    response.rpcs_by_method[method].rpcs_by_peer[peer] = count
            response.num_failures = self._no_remote_peer + self._rpcs_needed
        return response


_global_lock = threading.Lock()
_stop_event = threading.Event()
_global_rpc_id: int = 0
_watchers: Set[_StatsWatcher] = set()
_global_server = None
_global_rpcs_started: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_succeeded: Mapping[str, int] = collections.defaultdict(int)
_global_rpcs_failed: Mapping[str, int] = collections.defaultdict(int)

# Mapping[method, Mapping[status_code, count]]
_global_rpc_statuses: Mapping[str, Mapping[int, int]] = collections.defaultdict(
    lambda: collections.defaultdict(int))


def _handle_sigint(sig, frame) -> None:
    _stop_event.set()
    _global_server.stop(None)


class _LoadBalancerStatsServicer(
        test_pb2_grpc.LoadBalancerStatsServiceServicer):

    def __init__(self):
        super(_LoadBalancerStatsServicer, self).__init__()

    def GetClientStats(
            self, request: messages_pb2.LoadBalancerStatsRequest,
            context: grpc.ServicerContext
    ) -> messages_pb2.LoadBalancerStatsResponse:
        logger.info("Received stats request.")
        start = None
        end = None
        watcher = None
        with _global_lock:
            start = _global_rpc_id + 1
            end = start + request.num_rpcs
            watcher = _StatsWatcher(start, end)
            _watchers.add(watcher)
        response = watcher.await_rpc_stats_response(request.timeout_sec)
        with _global_lock:
            _watchers.remove(watcher)
        logger.info("Returning stats response: %s", response)
        return response

    def GetClientAccumulatedStats(
            self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
            context: grpc.ServicerContext
    ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
        logger.info("Received cumulative stats request.")
        response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
        with _global_lock:
            for method in _SUPPORTED_METHODS:
                caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
                response.num_rpcs_started_by_method[
                    caps_method] = _global_rpcs_started[method]
                response.num_rpcs_succeeded_by_method[
                    caps_method] = _global_rpcs_succeeded[method]
                response.num_rpcs_failed_by_method[
                    caps_method] = _global_rpcs_failed[method]
                response.stats_per_method[
                    caps_method].rpcs_started = _global_rpcs_started[method]
                for code, count in _global_rpc_statuses[method].items():
                    response.stats_per_method[caps_method].result[code] = count
        logger.info("Returning cumulative stats response.")
        return response


def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]],
               request_id: int, stub: test_pb2_grpc.TestServiceStub,
               timeout: float,
               futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
    logger.info(f"Sending {method} request to backend: {request_id}")
    if method == "UnaryCall":
        future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
                                       metadata=metadata,
                                       timeout=timeout)
    elif method == "EmptyCall":
        future = stub.EmptyCall.future(empty_pb2.Empty(),
                                       metadata=metadata,
                                       timeout=timeout)
    else:
        raise ValueError(f"Unrecognized method '{method}'.")
    futures[request_id] = (future, method)


def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str,
                 print_response: bool) -> None:
    exception = future.exception()
    hostname = ""
    _global_rpc_statuses[method][future.code().value[0]] += 1
    if exception is not None:
        with _global_lock:
            _global_rpcs_failed[method] += 1
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error(f"RPC {rpc_id} timed out")
        else:
            logger.error(exception)
    else:
        response = future.result()
        hostname = None
        for metadatum in future.initial_metadata():
            if metadatum[0] == "hostname":
                hostname = metadatum[1]
                break
        else:
            hostname = response.hostname
        if future.code() == grpc.StatusCode.OK:
            with _global_lock:
                _global_rpcs_succeeded[method] += 1
        else:
            with _global_lock:
                _global_rpcs_failed[method] += 1
        if print_response:
            if future.code() == grpc.StatusCode.OK:
                logger.info("Successful response.")
            else:
                logger.info(f"RPC failed: {rpc_id}")
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(rpc_id, hostname, method)


def _remove_completed_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]],
                           print_response: bool) -> None:
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, (future, method) in futures.items():
        if future.done():
            _on_rpc_done(future_id, future, method, print_response)
            done.append(future_id)
    for rpc_id in done:
        del futures[rpc_id]


def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
    logger.info("Cancelling all remaining RPCs")
    for future, _ in futures.values():
        future.cancel()


class _ChannelConfiguration:
    """Configuration for a single client channel.

    Instances of this class are meant to be dealt with as PODs. That is,
    data members should be accessed directly. This class is not thread-safe.
    When accessing any of its members, the condition member should be held.
    """

    def __init__(self, method: str, metadata: Sequence[Tuple[str, str]],
                 qps: int, server: str, rpc_timeout_sec: int,
                 print_response: bool):
        # condition is signalled when a change is made to the config.
        self.condition = threading.Condition()

        self.method = method
        self.metadata = metadata
        self.qps = qps
        self.server = server
        self.rpc_timeout_sec = rpc_timeout_sec
        self.print_response = print_response


def _run_single_channel(config: _ChannelConfiguration) -> None:
    global _global_rpc_id  # pylint: disable=global-statement
    with config.condition:
        server = config.server
    with grpc.insecure_channel(server) as channel:
        stub = test_pb2_grpc.TestServiceStub(channel)
        futures: Dict[int, Tuple[grpc.Future, str]] = {}
        while not _stop_event.is_set():
            with config.condition:
                if config.qps == 0:
                    config.condition.wait(
                        timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds())
                    continue
                else:
                    duration_per_query = 1.0 / float(config.qps)
            request_id = None
            with _global_lock:
                request_id = _global_rpc_id
                _global_rpc_id += 1
                _global_rpcs_started[config.method] += 1
            start = time.time()
            end = start + duration_per_query
            with config.condition:
                _start_rpc(config.method, config.metadata, request_id, stub,
                           float(config.rpc_timeout_sec), futures)
            with config.condition:
                _remove_completed_rpcs(futures, config.print_response)
            logger.debug(f"Currently {len(futures)} in-flight RPCs")
            # Sleep for the remainder of the query interval to maintain the
            # configured QPS.
            now = time.time()
            while now < end:
                time.sleep(end - now)
                now = time.time()
        _cancel_all_rpcs(futures)


class _XdsUpdateClientConfigureServicer(
        test_pb2_grpc.XdsUpdateClientConfigureServiceServicer):

    def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration],
                 qps: int):
        super(_XdsUpdateClientConfigureServicer, self).__init__()
        self._per_method_configs = per_method_configs
        self._qps = qps

    def Configure(
            self, request: messages_pb2.ClientConfigureRequest,
            context: grpc.ServicerContext
    ) -> messages_pb2.ClientConfigureResponse:
        logger.info("Received Configure RPC: %s", request)
        method_strs = [_METHOD_ENUM_TO_STR[t] for t in request.types]
        for method in _SUPPORTED_METHODS:
            method_enum = _METHOD_STR_TO_ENUM[method]
            channel_config = self._per_method_configs[method]
            if method in method_strs:
                qps = self._qps
                metadata = ((md.key, md.value)
                            for md in request.metadata
                            if md.type == method_enum)
                # For backward compatibility, do not change timeout when we
                # receive a default value timeout.
                if request.timeout_sec == 0:
                    timeout_sec = channel_config.rpc_timeout_sec
                else:
                    timeout_sec = request.timeout_sec
            else:
                qps = 0
                metadata = ()
                # Leave timeout unchanged for backward compatibility.
                timeout_sec = channel_config.rpc_timeout_sec
            with channel_config.condition:
                channel_config.qps = qps
                channel_config.metadata = list(metadata)
                channel_config.rpc_timeout_sec = timeout_sec
                channel_config.condition.notify_all()
        return messages_pb2.ClientConfigureResponse()


class _MethodHandle:
    """An object grouping together threads driving RPCs for a method."""

    _channel_threads: List[threading.Thread]

    def __init__(self, num_channels: int,
                 channel_config: _ChannelConfiguration):
        """Creates and starts a group of threads running the indicated method."""
        self._channel_threads = []
        for i in range(num_channels):
            thread = threading.Thread(target=_run_single_channel,
                                      args=(channel_config,))
            thread.start()
            self._channel_threads.append(thread)

    def stop(self) -> None:
        """Joins all threads referenced by the handle."""
        for channel_thread in self._channel_threads:
            channel_thread.join()


def _run(args: argparse.Namespace, methods: Sequence[str],
         per_method_metadata: PerMethodMetadataType) -> None:
    logger.info("Starting python xDS Interop Client.")
    global _global_server  # pylint: disable=global-statement
    method_handles = []
    channel_configs = {}
    for method in _SUPPORTED_METHODS:
        if method in methods:
            qps = args.qps
        else:
            qps = 0
        channel_config = _ChannelConfiguration(
            method, per_method_metadata.get(method, []), qps, args.server,
            args.rpc_timeout_sec, args.print_response)
        channel_configs[method] = channel_config
        method_handles.append(_MethodHandle(args.num_channels, channel_config))
    _global_server = grpc.server(futures.ThreadPoolExecutor())
    _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        _LoadBalancerStatsServicer(), _global_server)
    test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
        _XdsUpdateClientConfigureServicer(channel_configs, args.qps),
        _global_server)
    _global_server.start()
    _global_server.wait_for_termination()
    for method_handle in method_handles:
        method_handle.stop()


def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType:
    metadata = metadata_arg.split(",") if metadata_arg else []
    per_method_metadata = collections.defaultdict(list)
    for metadatum in metadata:
        elems = metadatum.split(":")
        if len(elems) != 3:
            raise ValueError(
                f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
        if elems[0] not in _SUPPORTED_METHODS:
            raise ValueError(f"Unrecognized method '{elems[0]}'")
        per_method_metadata[elems[0]].append((elems[1], elems[2]))
    return per_method_metadata


def parse_rpc_arg(rpc_arg: str) -> Sequence[str]:
    methods = rpc_arg.split(",")
    if set(methods) - set(_SUPPORTED_METHODS):
        raise ValueError("--rpc supported methods: {}".format(
            ", ".join(_SUPPORTED_METHODS)))
    return methods


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Run Python XDS interop client.')
    parser.add_argument(
        "--num_channels",
        default=1,
        type=int,
        help="The number of channels from which to send requests.")
    parser.add_argument("--print_response",
                        default=False,
                        action="store_true",
                        help="Write RPC response to STDOUT.")
    parser.add_argument(
        "--qps",
        default=1,
        type=int,
        help="The number of queries to send from each channel per second.")
second.") 440 parser.add_argument("--rpc_timeout_sec", 441 default=30, 442 type=int, 443 help="The per-RPC timeout in seconds.") 444 parser.add_argument("--server", 445 default="localhost:50051", 446 help="The address of the server.") 447 parser.add_argument( 448 "--stats_port", 449 default=50052, 450 type=int, 451 help="The port on which to expose the peer distribution stats service.") 452 parser.add_argument('--verbose', 453 help='verbose log output', 454 default=False, 455 action='store_true') 456 parser.add_argument("--log_file", 457 default=None, 458 type=str, 459 help="A file to log to.") 460 rpc_help = "A comma-delimited list of RPC methods to run. Must be one of " 461 rpc_help += ", ".join(_SUPPORTED_METHODS) 462 rpc_help += "." 463 parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help) 464 metadata_help = ( 465 "A comma-delimited list of 3-tuples of the form " + 466 "METHOD:KEY:VALUE, e.g. " + 467 "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3") 468 parser.add_argument("--metadata", default="", type=str, help=metadata_help) 469 args = parser.parse_args() 470 signal.signal(signal.SIGINT, _handle_sigint) 471 if args.verbose: 472 logger.setLevel(logging.DEBUG) 473 if args.log_file: 474 file_handler = logging.FileHandler(args.log_file, mode='a') 475 file_handler.setFormatter(formatter) 476 logger.addHandler(file_handler) 477 _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata)) 478