# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python implementation of the QPS benchmark WorkerService."""

import multiprocessing
import random
import threading
import time

from concurrent import futures
import grpc
from src.proto.grpc.testing import control_pb2
from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import worker_service_pb2_grpc
from src.proto.grpc.testing import stats_pb2

from tests.qps import benchmark_client
from tests.qps import benchmark_server
from tests.qps import client_runner
from tests.qps import histogram
from tests.unit import resources
from tests.unit import test_common


class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation.

    Serves the WorkerService RPCs:
      * RunServer / RunClient start a benchmark server or client(s) and
        stream status messages back for each incoming mark request.
      * CoreCount reports multiprocessing.cpu_count().
      * QuitWorker sets an event that wait_for_quit() blocks on.
    """

    def __init__(self):
        # Set by QuitWorker; wait_for_quit() blocks until it is set.
        self._quit_event = threading.Event()

    def RunServer(self, request_iterator, context):
        """Start a benchmark server and stream its status.

        The first request must carry the server ``setup`` config. One
        ServerStatus is yielded right after startup and one more for each
        subsequent request; a request with ``mark.reset`` set restarts the
        elapsed-time window after its status has been computed.

        The benchmark server is stopped when the request stream ends, and
        also if the response generator is torn down early (error,
        cancellation, or close()).
        """
        config = next(request_iterator).setup  #pylint: disable=stop-iteration-return
        server, port = self._create_server(config)
        cores = multiprocessing.cpu_count()
        server.start()
        try:
            start_time = time.time()
            yield self._get_server_status(start_time, start_time, port, cores)

            for request in request_iterator:
                end_time = time.time()
                status = self._get_server_status(start_time, end_time, port,
                                                 cores)
                if request.mark.reset:
                    # Next window starts exactly where the reported one ended.
                    start_time = end_time
                yield status
        finally:
            # Runs on normal completion and on GeneratorExit/exceptions, so
            # the benchmark server never outlives this RPC.
            server.stop(None)

    def _get_server_status(self, start_time, end_time, port, cores):
        """Build a ServerStatus covering the window [start_time, end_time].

        Args:
          start_time: Window start, in time.time() seconds.
          end_time: Window end, in time.time() seconds.
          port: Port the benchmark server is bound to.
          cores: Number of cores to report.

        NOTE(review): time_user and time_system are filled with the
        wall-clock elapsed time, not measured CPU time.
        """
        elapsed_time = end_time - start_time
        stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                      time_user=elapsed_time,
                                      time_system=elapsed_time)
        return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)

    def _create_server(self, config):
        """Create (but do not start) a benchmark server for ``config``.

        Returns:
          A ``(server, port)`` tuple, where ``port`` is the value returned
          by add_secure_port/add_insecure_port.

        Raises:
          Exception: if ``config.server_type`` is not ASYNC_SERVER or
            ASYNC_GENERIC_SERVER.
        """
        if config.async_server_threads == 0:
            # cpu_count() * 5 is the historical concurrent.futures
            # ThreadPoolExecutor default (Python 3.5-3.7); passing None
            # here did not work.
            server_threads = multiprocessing.cpu_count() * 5
        else:
            server_threads = config.async_server_threads
        server = test_common.test_server(max_workers=server_threads)
        if config.server_type == control_pb2.ASYNC_SERVER:
            servicer = benchmark_server.BenchmarkServer()
            benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
                servicer, server)
        elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
            resp_size = config.payload_config.bytebuf_params.resp_size
            servicer = benchmark_server.GenericBenchmarkServer(resp_size)
            method_implementations = {
                'StreamingCall':
                    grpc.stream_stream_rpc_method_handler(
                        servicer.StreamingCall),
                'UnaryCall':
                    grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
            }
            handler = grpc.method_handlers_generic_handler(
                'grpc.testing.BenchmarkService', method_implementations)
            server.add_generic_rpc_handlers((handler,))
        else:
            raise Exception('Unsupported server type {}'.format(
                config.server_type))

        if config.HasField('security_params'):  # Use SSL
            server_creds = grpc.ssl_server_credentials(
                ((resources.private_key(), resources.certificate_chain()),))
            port = server.add_secure_port('[::]:{}'.format(config.port),
                                          server_creds)
        else:
            port = server.add_insecure_port('[::]:{}'.format(config.port))

        return (server, port)

    def RunClient(self, request_iterator, context):
        """Start benchmark client runners and stream their status.

        The first request must carry the client ``setup`` config. One client
        runner is created per channel, with channels assigned round-robin to
        ``config.server_targets``. A ClientStatus is yielded once all runners
        have started and again for each subsequent request; ``mark.reset``
        clears the latency histogram and restarts the time window.

        All started runners are stopped when the request stream ends, and
        also if the response generator is torn down early.
        """
        config = next(request_iterator).setup  #pylint: disable=stop-iteration-return
        client_runners = []
        qps_data = histogram.Histogram(config.histogram_params.resolution,
                                       config.histogram_params.max_possible)
        start_time = time.time()

        try:
            # Create a client for each channel
            for i in range(config.client_channels):
                server = config.server_targets[i % len(config.server_targets)]
                runner = self._create_client_runner(server, config, qps_data)
                runner.start()
                # Only track runners that actually started, so cleanup below
                # never stops a runner that was never running.
                client_runners.append(runner)

            end_time = time.time()
            yield self._get_client_status(start_time, end_time, qps_data)

            # Respond to stat requests
            for request in request_iterator:
                end_time = time.time()
                status = self._get_client_status(start_time, end_time,
                                                 qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    start_time = time.time()
                yield status
        finally:
            # Cleanup the clients, including on cancellation or errors.
            for runner in client_runners:
                runner.stop()

    def _get_client_status(self, start_time, end_time, qps_data):
        """Build a ClientStatus with the current latency histogram data.

        Args:
          start_time: Window start, in time.time() seconds.
          end_time: Window end, in time.time() seconds.
          qps_data: The histogram.Histogram shared by the client runners.

        NOTE(review): time_user and time_system are filled with the
        wall-clock elapsed time, not measured CPU time.
        """
        latencies = qps_data.get_data()
        elapsed_time = end_time - start_time
        stats = stats_pb2.ClientStats(latencies=latencies,
                                      time_elapsed=elapsed_time,
                                      time_user=elapsed_time,
                                      time_system=elapsed_time)
        return control_pb2.ClientStatus(stats=stats)

    def _create_client_runner(self, server, config, qps_data):
        """Create a (not yet started) client runner targeting ``server``.

        Raises:
          Exception: for unsupported client type / rpc type combinations.
        """
        if config.client_type == control_pb2.SYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnarySyncBenchmarkClient(
                    server, config, qps_data)
            elif config.rpc_type == control_pb2.STREAMING:
                client = benchmark_client.StreamingSyncBenchmarkClient(
                    server, config, qps_data)
            else:
                # Previously fell through with `client` unbound, surfacing
                # as an UnboundLocalError further down.
                raise Exception('Unsupported rpc type {}'.format(
                    config.rpc_type))
        elif config.client_type == control_pb2.ASYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnaryAsyncBenchmarkClient(
                    server, config, qps_data)
            else:
                raise Exception('Async streaming client not supported')
        else:
            raise Exception('Unsupported client type {}'.format(
                config.client_type))

        if config.load_params.WhichOneof('load') == 'closed_loop':
            runner = client_runner.ClosedLoopClientRunner(
                client, config.outstanding_rpcs_per_channel)
        else:  # Open loop Poisson
            # In multi-channel tests, we split the load across all channels
            load_factor = float(config.client_channels)
            alpha = config.load_params.poisson.offered_load / load_factor

            def poisson():
                # Infinite stream of exponential inter-arrival gaps.
                while True:
                    yield random.expovariate(alpha)

            runner = client_runner.OpenLoopClientRunner(client, poisson())

        return runner

    def CoreCount(self, request, context):
        """Return the number of CPUs reported by multiprocessing."""
        return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())

    def QuitWorker(self, request, context):
        """Signal wait_for_quit() to return and acknowledge the request."""
        self._quit_event.set()
        return control_pb2.Void()

    def wait_for_quit(self):
        """Block until QuitWorker has been called."""
        self._quit_event.wait()