# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import random
import threading
import time

from concurrent import futures

import grpc
from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import control_pb2
from src.proto.grpc.testing import stats_pb2
from src.proto.grpc.testing import worker_service_pb2_grpc

from tests.qps import benchmark_client
from tests.qps import benchmark_server
from tests.qps import client_runner
from tests.qps import histogram
from tests.unit import resources
from tests.unit import test_common


class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation.

    Implements the grpc.testing.WorkerService contract: the QPS driver
    streams server/client configs in and receives periodic status
    snapshots back.
    """

    def __init__(self, server_port=None):
        # Set by QuitWorker; wait_for_quit() blocks on it.
        self._quit_event = threading.Event()
        # Optional fixed port to use when the driver asks for port 0
        # (i.e. "pick any port").
        self._server_port = server_port

    def RunServer(self, request_iterator, context):
        """Bidi-streaming RPC: run a benchmark server under driver control.

        The first request carries the server setup; each subsequent request
        is a Mark asking for a stats snapshot (optionally resetting the
        measurement window). Yields a ServerStatus per request and stops the
        server when the driver closes its stream.
        """
        config = next(request_iterator).setup  #pylint: disable=stop-iteration-return
        server, port = self._create_server(config)
        cores = multiprocessing.cpu_count()
        server.start()
        start_time = time.time()
        # Initial status: zero-length measurement window.
        yield self._get_server_status(start_time, start_time, port, cores)

        for request in request_iterator:
            end_time = time.time()
            status = self._get_server_status(start_time, end_time, port, cores)
            if request.mark.reset:
                # Start a fresh measurement window at this snapshot.
                start_time = end_time
            yield status
        server.stop(None)

    def _get_server_status(self, start_time, end_time, port, cores):
        """Build a ServerStatus for the window [start_time, end_time].

        Previously this method clobbered the end_time argument with
        time.time(), discarding the timestamp the caller computed for its
        reset bookkeeping; it now honors the parameter.
        """
        elapsed_time = end_time - start_time
        # Per-mode CPU time is not tracked; report wall time for all fields.
        stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                      time_user=elapsed_time,
                                      time_system=elapsed_time)
        return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)

    def _create_server(self, config):
        """Build (but do not start) a benchmark server from a ServerConfig.

        Returns a (server, bound_port) tuple. Raises for unsupported
        server types.
        """
        if config.async_server_threads == 0:
            # This is the default concurrent.futures thread pool size, but
            # None doesn't seem to work
            server_threads = multiprocessing.cpu_count() * 5
        else:
            server_threads = config.async_server_threads
        server = test_common.test_server(max_workers=server_threads)
        if config.server_type == control_pb2.ASYNC_SERVER:
            servicer = benchmark_server.BenchmarkServer()
            benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
                servicer, server)
        elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
            # Generic (serialization-free) server: echo fixed-size byte
            # buffers as configured by the payload params.
            resp_size = config.payload_config.bytebuf_params.resp_size
            servicer = benchmark_server.GenericBenchmarkServer(resp_size)
            method_implementations = {
                'StreamingCall':
                    grpc.stream_stream_rpc_method_handler(servicer.StreamingCall
                                                         ),
                'UnaryCall':
                    grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
            }
            handler = grpc.method_handlers_generic_handler(
                'grpc.testing.BenchmarkService', method_implementations)
            server.add_generic_rpc_handlers((handler,))
        else:
            raise Exception('Unsupported server type {}'.format(
                config.server_type))

        # Honor an explicit worker-level port override only when the driver
        # did not request a specific port itself.
        if self._server_port is not None and config.port == 0:
            server_port = self._server_port
        else:
            server_port = config.port

        if config.HasField('security_params'):  # Use SSL
            server_creds = grpc.ssl_server_credentials(
                ((resources.private_key(), resources.certificate_chain()),))
            port = server.add_secure_port('[::]:{}'.format(server_port),
                                          server_creds)
        else:
            port = server.add_insecure_port('[::]:{}'.format(server_port))

        return (server, port)
def RunClient(self, request_iterator, context): 111 config = next(request_iterator).setup #pylint: disable=stop-iteration-return 112 client_runners = [] 113 qps_data = histogram.Histogram(config.histogram_params.resolution, 114 config.histogram_params.max_possible) 115 start_time = time.time() 116 117 # Create a client for each channel 118 for i in range(config.client_channels): 119 server = config.server_targets[i % len(config.server_targets)] 120 runner = self._create_client_runner(server, config, qps_data) 121 client_runners.append(runner) 122 runner.start() 123 124 end_time = time.time() 125 yield self._get_client_status(start_time, end_time, qps_data) 126 127 # Respond to stat requests 128 for request in request_iterator: 129 end_time = time.time() 130 status = self._get_client_status(start_time, end_time, qps_data) 131 if request.mark.reset: 132 qps_data.reset() 133 start_time = time.time() 134 yield status 135 136 # Cleanup the clients 137 for runner in client_runners: 138 runner.stop() 139 140 def _get_client_status(self, start_time, end_time, qps_data): 141 latencies = qps_data.get_data() 142 end_time = time.time() 143 elapsed_time = end_time - start_time 144 stats = stats_pb2.ClientStats(latencies=latencies, 145 time_elapsed=elapsed_time, 146 time_user=elapsed_time, 147 time_system=elapsed_time) 148 return control_pb2.ClientStatus(stats=stats) 149 150 def _create_client_runner(self, server, config, qps_data): 151 no_ping_pong = False 152 if config.client_type == control_pb2.SYNC_CLIENT: 153 if config.rpc_type == control_pb2.UNARY: 154 client = benchmark_client.UnarySyncBenchmarkClient( 155 server, config, qps_data) 156 elif config.rpc_type == control_pb2.STREAMING: 157 client = benchmark_client.StreamingSyncBenchmarkClient( 158 server, config, qps_data) 159 elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER: 160 no_ping_pong = True 161 client = benchmark_client.ServerStreamingSyncBenchmarkClient( 162 server, config, qps_data) 163 elif 
config.client_type == control_pb2.ASYNC_CLIENT: 164 if config.rpc_type == control_pb2.UNARY: 165 client = benchmark_client.UnaryAsyncBenchmarkClient( 166 server, config, qps_data) 167 else: 168 raise Exception('Async streaming client not supported') 169 else: 170 raise Exception('Unsupported client type {}'.format( 171 config.client_type)) 172 173 # In multi-channel tests, we split the load across all channels 174 load_factor = float(config.client_channels) 175 if config.load_params.WhichOneof('load') == 'closed_loop': 176 runner = client_runner.ClosedLoopClientRunner( 177 client, config.outstanding_rpcs_per_channel, no_ping_pong) 178 else: # Open loop Poisson 179 alpha = config.load_params.poisson.offered_load / load_factor 180 181 def poisson(): 182 while True: 183 yield random.expovariate(alpha) 184 185 runner = client_runner.OpenLoopClientRunner(client, poisson()) 186 187 return runner 188 189 def CoreCount(self, request, context): 190 return control_pb2.CoreResponse(cores=multiprocessing.cpu_count()) 191 192 def QuitWorker(self, request, context): 193 self._quit_event.set() 194 return control_pb2.Void() 195 196 def wait_for_quit(self): 197 self._quit_event.wait() 198