# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point for running stress tests."""

from concurrent import futures
import queue
import threading

from absl import app
from absl.flags import argparse_flags
import grpc

from src.proto.grpc.testing import metrics_pb2_grpc
from src.proto.grpc.testing import test_pb2_grpc
from tests.interop import methods
from tests.interop import resources
from tests.qps import histogram
from tests.stress import metrics_server
from tests.stress import test_runner

# Accepted spellings for boolean flag values (compared case-insensitively).
# The empty string stays falsy, matching what `type=bool` used to do with it.
_TRUE_STRINGS = frozenset(("true", "t", "yes", "y", "1"))
_FALSE_STRINGS = frozenset(("false", "f", "no", "n", "0", ""))


def _parse_bool(value):
    """Convert a command-line string into a bool.

    `type=bool` must not be used with argparse: `bool("false")` is True, so
    every non-empty value would enable the flag.

    Args:
      value: the raw string supplied on the command line.

    Returns:
      True or False according to the recognized spellings.

    Raises:
      ValueError: if the string is not a recognized boolean spelling.
        argparse turns this into a regular usage error.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ValueError("Not a boolean value: {!r}".format(value))


def _args(argv):
    """Parse the stress-test command-line flags.

    Args:
      argv: the full argument vector; the first element is skipped before
        parsing.

    Returns:
      The namespace produced by the parser.
    """
    parser = argparse_flags.ArgumentParser()
    parser.add_argument(
        "--server_addresses",
        help="comma separated list of hostname:port to run servers on",
        default="localhost:8080",
        type=str,
    )
    parser.add_argument(
        "--test_cases",
        help="comma separated list of testcase:weighting of tests to run",
        default="large_unary:100",
        type=str,
    )
    parser.add_argument(
        "--test_duration_secs",
        help="number of seconds to run the stress test",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--num_channels_per_server",
        help="number of channels per server",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--num_stubs_per_channel",
        help="number of stubs to create per channel",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--metrics_port",
        help="the port to listen for metrics requests on",
        default=8081,
        type=int,
    )
    parser.add_argument(
        "--use_test_ca",
        help="Whether to use our fake CA. Requires --use_tls=true",
        default=False,
        type=_parse_bool,
    )
    parser.add_argument(
        "--use_tls",
        help="Whether to use TLS",
        default=False,
        type=_parse_bool,
    )
    parser.add_argument(
        "--server_host_override",
        help="the server host to which to claim to connect",
        type=str,
    )
    return parser.parse_args(argv[1:])


def _test_case_from_arg(test_case_arg):
    """Look up the `methods.TestCase` member whose value equals the argument.

    Args:
      test_case_arg: the test case name as given on the command line.

    Returns:
      The matching `methods.TestCase` member.

    Raises:
      ValueError: if no member has that value.
    """
    for test_case in methods.TestCase:
        if test_case_arg == test_case.value:
            return test_case
    raise ValueError("No test case {}!".format(test_case_arg))


def _parse_weighted_test_cases(test_case_args):
    """Parse a `name:weight,name:weight,...` string into a dict.

    Args:
      test_case_args: comma separated list of `testcase:weighting` entries.

    Returns:
      A dict mapping each test case to its integer weight. If a test case is
      listed more than once, the last weight wins.

    Raises:
      ValueError: if an entry has no `:weight` part, the weight is not an
        integer, or the test case name is unknown.
    """
    weighted_test_cases = {}
    for test_case_arg in test_case_args.split(","):
        name, separator, weight = test_case_arg.partition(":")
        if not separator:
            raise ValueError(
                "Expected testcase:weighting but got {!r}".format(
                    test_case_arg
                )
            )
        test_case = _test_case_from_arg(name)
        weighted_test_cases[test_case] = int(weight)
    return weighted_test_cases


def _get_channel(target, args):
    """Create a channel to `target` and block until it is ready.

    Args:
      target: the `hostname:port` of the server to connect to.
      args: parsed flags; reads `use_tls`, `use_test_ca` and
        `server_host_override`.

    Returns:
      A connected `grpc.Channel`, secure when `args.use_tls` is set and
      insecure otherwise.
    """
    if args.use_tls:
        if args.use_test_ca:
            root_certificates = resources.test_root_certificates()
        else:
            root_certificates = None  # will load default roots.
        channel_credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates
        )
        # Only override the target name when one was actually supplied; the
        # flag has no default, so passing it unconditionally would hand gRPC
        # a None option value.
        if args.server_host_override:
            options = (
                (
                    "grpc.ssl_target_name_override",
                    args.server_host_override,
                ),
            )
        else:
            options = ()
        channel = grpc.secure_channel(
            target, channel_credentials, options=options
        )
    else:
        channel = grpc.insecure_channel(target)

    # waits for the channel to be ready before we start sending messages
    grpc.channel_ready_future(channel).result()
    return channel


def run_test(args):
    """Run the stress test described by the parsed flags.

    Starts a metrics server on `args.metrics_port`, creates
    `num_channels_per_server` channels per server address and
    `num_stubs_per_channel` runners per channel, then waits until a runner
    reports a failure through the exception queue or the test duration
    elapses.

    Args:
      args: the namespace returned by `_args`.

    Raises:
      Whatever object a runner puts on the exception queue (presumably an
      exception instance) is re-raised here.
    """
    test_cases = _parse_weighted_test_cases(args.test_cases)
    test_server_targets = args.server_addresses.split(",")
    # Propagate any client exceptions with a queue
    exception_queue = queue.Queue()
    stop_event = threading.Event()
    hist = histogram.Histogram(1, 1)
    runners = []
    # Runners that were actually started; only these may be joined.
    started_runners = []

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
    metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
        metrics_server.MetricsServer(hist), server
    )
    server.add_insecure_port("[::]:{}".format(args.metrics_port))
    server.start()

    # Everything after server.start() runs under try/finally so the metrics
    # server and any started runners are shut down even if setup fails.
    try:
        for test_server_target in test_server_targets:
            for _ in range(args.num_channels_per_server):
                channel = _get_channel(test_server_target, args)
                for _ in range(args.num_stubs_per_channel):
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    runner = test_runner.TestRunner(
                        stub, test_cases, hist, exception_queue, stop_event
                    )
                    runners.append(runner)

        for runner in runners:
            runner.start()
            started_runners.append(runner)

        # A negative duration means "run until a runner fails".
        timeout_secs = args.test_duration_secs
        if timeout_secs < 0:
            timeout_secs = None
        try:
            raise exception_queue.get(block=True, timeout=timeout_secs)
        except queue.Empty:
            # No exceptions thrown, success
            pass
    finally:
        stop_event.set()
        for runner in started_runners:
            runner.join()
        server.stop(None)


if __name__ == "__main__":
    app.run(run_test, flags_parser=_args)