• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Entry point for running stress tests."""
15
16import argparse
17from concurrent import futures
18import threading
19
20import grpc
21from six.moves import queue
22from src.proto.grpc.testing import metrics_pb2_grpc
23from src.proto.grpc.testing import test_pb2_grpc
24
25from tests.interop import methods
26from tests.interop import resources
27from tests.qps import histogram
28from tests.stress import metrics_server
29from tests.stress import test_runner
30
31
def _parse_bool(value):
    """Convert a textual command-line flag value to a bool.

    argparse's `type=bool` cannot be used for this: it calls `bool()` on the
    raw string, and every non-empty string (including 'false') is truthy.

    Args:
      value: the string given on the command line.

    Returns:
      True for 'true', 't', 'yes', 'y' or '1'; False for 'false', 'f', 'no',
      'n', '0' or the empty string. Matching ignores case and surrounding
      whitespace.

    Raises:
      argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under `type=bool`.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def _args():
    """Parse the stress test client's command-line arguments.

    Returns:
      The argparse namespace with attributes server_addresses, test_cases,
      test_duration_secs, num_channels_per_server, num_stubs_per_channel,
      metrics_port, use_test_ca, use_tls and server_host_override.
    """
    parser = argparse.ArgumentParser(
        description='gRPC Python stress test client')
    parser.add_argument(
        '--server_addresses',
        help='comma separated list of hostname:port to run servers on',
        default='localhost:8080',
        type=str)
    parser.add_argument(
        '--test_cases',
        help='comma separated list of testcase:weighting of tests to run',
        default='large_unary:100',
        type=str)
    parser.add_argument('--test_duration_secs',
                        help='number of seconds to run the stress test',
                        default=-1,
                        type=int)
    parser.add_argument('--num_channels_per_server',
                        help='number of channels per server',
                        default=1,
                        type=int)
    parser.add_argument('--num_stubs_per_channel',
                        help='number of stubs to create per channel',
                        default=1,
                        type=int)
    parser.add_argument('--metrics_port',
                        help='the port to listen for metrics requests on',
                        default=8081,
                        type=int)
    # Boolean flags go through _parse_bool so that '--use_tls=false' really
    # disables TLS instead of being treated as a truthy non-empty string.
    parser.add_argument(
        '--use_test_ca',
        help='Whether to use our fake CA. Requires --use_tls=true',
        default=False,
        type=_parse_bool)
    parser.add_argument('--use_tls',
                        help='Whether to use TLS',
                        default=False,
                        type=_parse_bool)
    parser.add_argument('--server_host_override',
                        help='the server host to which to claim to connect',
                        type=str)
    return parser.parse_args()
74
75
def _test_case_from_arg(test_case_arg):
    """Return the methods.TestCase member whose value equals test_case_arg.

    Raises:
      ValueError: if no member of methods.TestCase has that value.
    """
    matching = (candidate for candidate in methods.TestCase
                if test_case_arg == candidate.value)
    found = next(matching, None)
    if found is None:
        raise ValueError('No test case {}!'.format(test_case_arg))
    return found
82
83
def _parse_weighted_test_cases(test_case_args):
    """Parse a 'name:weight,name:weight,...' string into a weight mapping.

    Args:
      test_case_args: comma separated 'testcase:weighting' entries.

    Returns:
      A dict mapping each resolved test case to its integer weight. If a
      test case appears more than once, the last weight wins.

    Raises:
      ValueError: if an entry has no ':' separator, names an unknown test
        case, or has a weight that is not an integer.
    """
    weights_by_case = {}
    entries = test_case_args.split(',')
    for entry in entries:
        # Split on the first ':' only; anything after it is the weight.
        case_name, raw_weight = entry.split(':', 1)
        resolved_case = _test_case_from_arg(case_name)
        weights_by_case[resolved_case] = int(raw_weight)
    return weights_by_case
91
92
def _get_channel(target, args):
    """Create a channel to `target` and block until it is ready.

    Args:
      target: the server address passed to grpc.secure_channel or
        grpc.insecure_channel.
      args: parsed command-line arguments; reads use_tls, use_test_ca and
        server_host_override.

    Returns:
      The created channel, returned only after its channel_ready_future has
      completed.
    """
    if args.use_tls:
        if args.use_test_ca:
            root_certificates = resources.test_root_certificates()
        else:
            root_certificates = None  # will load default roots.
        channel_credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates)
        # Only request a target-name override when one was supplied:
        # --server_host_override has no default, and None is not a valid
        # channel-argument value.
        options = ()
        if args.server_host_override:
            options = ((
                'grpc.ssl_target_name_override',
                args.server_host_override,
            ),)
        channel = grpc.secure_channel(target,
                                      channel_credentials,
                                      options=options)
    else:
        channel = grpc.insecure_channel(target)

    # waits for the channel to be ready before we start sending messages
    # (no timeout is given, so this blocks until the channel connects)
    grpc.channel_ready_future(channel).result()
    return channel
114
115
def run_test(args):
    """Run the stress test described by the parsed command-line `args`.

    Starts a metrics server on args.metrics_port, creates
    num_channels_per_server channels per server address and
    num_stubs_per_channel runners per channel, then waits for a runner to
    report an exception or for test_duration_secs to elapse (a negative
    duration waits indefinitely).

    Args:
      args: the namespace returned by _args().

    Raises:
      Whatever exception object a runner placed on the exception queue, or
      any error raised while parsing test cases or setting up channels.
    """
    test_cases = _parse_weighted_test_cases(args.test_cases)
    test_server_targets = args.server_addresses.split(',')
    # Propagate any client exceptions with a queue
    exception_queue = queue.Queue()
    stop_event = threading.Event()
    hist = histogram.Histogram(1, 1)
    started_runners = []

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
    metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
        metrics_server.MetricsServer(hist), server)
    server.add_insecure_port('[::]:{}'.format(args.metrics_port))
    server.start()

    # Everything after server.start() runs inside the try block so that the
    # metrics server is stopped even if channel or runner setup fails.
    try:
        runners = []
        for test_server_target in test_server_targets:
            for _ in range(args.num_channels_per_server):
                channel = _get_channel(test_server_target, args)
                for _ in range(args.num_stubs_per_channel):
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    runners.append(
                        test_runner.TestRunner(stub, test_cases, hist,
                                               exception_queue, stop_event))

        for runner in runners:
            runner.start()
            # Track started runners separately so cleanup only joins those.
            # NOTE(review): presumably TestRunner is a threading.Thread,
            # where joining an unstarted thread raises - confirm.
            started_runners.append(runner)

        timeout_secs = args.test_duration_secs
        if timeout_secs < 0:
            # Queue.get treats a timeout of None as "wait forever".
            timeout_secs = None
        # The first exception reported by any runner fails the whole test.
        raise exception_queue.get(block=True, timeout=timeout_secs)
    except queue.Empty:
        # No exceptions thrown, success
        pass
    finally:
        stop_event.set()
        for runner in started_runners:
            runner.join()
        server.stop(None)
156
157
# Script entry point: parse the command line, then run the stress test.
if __name__ == '__main__':
    parsed_args = _args()
    run_test(parsed_args)
160