# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import collections
import logging
import multiprocessing
import os
import sys
import time
from typing import Tuple

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
                                    stats_pb2, worker_service_pb2_grpc)
from tests.qps import histogram
from tests.unit import resources
from tests.unit.framework.common import get_socket
from tests_aio.benchmark import benchmark_client, benchmark_servicer

# Number of CPUs reported by the OS; used as the default process count and
# reported back to the driver in ServerStatus / CoreResponse.
_NUM_CORES = multiprocessing.cpu_count()
# Script launched (with --driver_port) to start a child qps worker; it is
# expected to sit next to this module.
_WORKER_ENTRY_FILE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0], 'worker.py')

_LOGGER = logging.getLogger(__name__)


class _SubWorker(
        collections.namedtuple('_SubWorker',
                               ['process', 'port', 'channel', 'stub'])):
    """A data class that holds information about a child qps worker.

    Fields:
      process: the asyncio subprocess handle of the child worker.
      port: the port passed to the child as --driver_port.
      channel: the aio channel connected to localhost:<port>.
      stub: the WorkerServiceStub created on top of `channel`.
    """

    def _repr(self):
        return f'<_SubWorker pid={self.process.pid} port={self.port}>'

    def __repr__(self):
        return self._repr()

    def __str__(self):
        return self._repr()


def _get_server_status(start_time: float, end_time: float,
                       port: int) -> control_pb2.ServerStatus:
    """Creates ServerStatus proto message.

    Args:
      start_time: time.monotonic() reading at the start of the window.
      end_time: time.monotonic() reading at the end of the window.
      port: the port the benchmark server is listening on.

    Returns:
      A ServerStatus whose elapsed/user/system times are all set to
      `end_time - start_time`, with `port` and the local core count.
    """
    # Use the caller's end_time rather than sampling the clock again, so the
    # reported window lines up with the caller's `start_time = end_time`
    # bookkeeping on reset.
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES)


def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
    """Creates a server object according to the ServerConfig.

    Args:
      config: the ServerConfig received from the driver.

    Returns:
      A tuple of the (not yet started) aio server and the port value returned
      by add_secure_port / add_insecure_port.

    Raises:
      NotImplementedError: if config.server_type is neither ASYNC_SERVER nor
        ASYNC_GENERIC_SERVER.
    """
    # Each channel arg carries either a string or an integer value.
    channel_args = tuple(
        (arg.name,
         arg.str_value) if arg.HasField('str_value') else (arg.name,
                                                           int(arg.int_value))
        for arg in config.channel_args)

    # NOTE(review): RunServer forwards the same port to every sub worker,
    # which is presumably why so_reuseport is always enabled here.
    server = aio.server(options=channel_args + (('grpc.so_reuseport', 1),))
    if config.server_type == control_pb2.ASYNC_SERVER:
        servicer = benchmark_servicer.BenchmarkServicer()
        benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
            servicer, server)
    elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
        resp_size = config.payload_config.bytebuf_params.resp_size
        servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size)
        method_implementations = {
            'StreamingCall':
                grpc.stream_stream_rpc_method_handler(servicer.StreamingCall),
            'UnaryCall':
                grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
        }
        handler = grpc.method_handlers_generic_handler(
            'grpc.testing.BenchmarkService', method_implementations)
        server.add_generic_rpc_handlers((handler,))
    else:
        raise NotImplementedError('Unsupported server type {}'.format(
            config.server_type))

    if config.HasField('security_params'):  # Use SSL
        server_creds = grpc.ssl_server_credentials(
            ((resources.private_key(), resources.certificate_chain()),))
        port = server.add_secure_port('[::]:{}'.format(config.port),
                                      server_creds)
    else:
        port = server.add_insecure_port('[::]:{}'.format(config.port))

    return server, port


def _get_client_status(
        start_time: float, end_time: float,
        qps_data: histogram.Histogram) -> control_pb2.ClientStatus:
    """Creates ClientStatus proto message.

    Args:
      start_time: time.monotonic() reading at the start of the window.
      end_time: time.monotonic() reading at the end of the window.
      qps_data: the histogram whose current data is reported as latencies.

    Returns:
      A ClientStatus whose elapsed/user/system times are all set to
      `end_time - start_time`.
    """
    latencies = qps_data.get_data()
    # Use the caller's end_time instead of overwriting it with a new clock
    # reading; callers sample it at the point they want the window to end.
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    stats = stats_pb2.ClientStats(latencies=latencies,
                                  time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ClientStatus(stats=stats)


def _create_client(
        server: str, config: control_pb2.ClientConfig,
        qps_data: histogram.Histogram) -> benchmark_client.BenchmarkClient:
    """Creates a client object according to the ClientConfig.

    Args:
      server: the server target this client connects to.
      config: the ClientConfig received from the driver.
      qps_data: the histogram handed to the client.

    Raises:
      NotImplementedError: if the load parameter is not 'closed_loop', the
        client type is not ASYNC_CLIENT, or the rpc type is not one of UNARY,
        STREAMING or STREAMING_FROM_SERVER.
    """
    if config.load_params.WhichOneof('load') != 'closed_loop':
        raise NotImplementedError(
            f'Unsupported load parameter {config.load_params}')

    if config.client_type == control_pb2.ASYNC_CLIENT:
        if config.rpc_type == control_pb2.UNARY:
            client_type = benchmark_client.UnaryAsyncBenchmarkClient
        elif config.rpc_type == control_pb2.STREAMING:
            client_type = benchmark_client.StreamingAsyncBenchmarkClient
        elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER:
            client_type = benchmark_client.ServerStreamingAsyncBenchmarkClient
        else:
            raise NotImplementedError(
                f'Unsupported rpc_type [{config.rpc_type}]')
    else:
        raise NotImplementedError(
            f'Unsupported client type {config.client_type}')

    return client_type(server, config, qps_data)


def _pick_an_unused_port() -> int:
    """Picks an unused TCP port.

    The socket obtained from get_socket() is closed before returning, so the
    port is only known to have been free at the time of the call.
    """
    _, port, sock = get_socket()
    sock.close()
    return port


async def _create_sub_worker() -> _SubWorker:
    """Creates a child qps worker as a subprocess.

    Launches worker.py with the current interpreter on a freshly picked
    driver port, then waits until a channel to that port is ready.

    Returns:
      A _SubWorker describing the child process and the stub to control it.
    """
    port = _pick_an_unused_port()

    _LOGGER.info('Creating sub worker at port [%d]...', port)
    process = await asyncio.create_subprocess_exec(sys.executable,
                                                   _WORKER_ENTRY_FILE,
                                                   '--driver_port', str(port))
    _LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port,
                 process.pid)
    channel = aio.insecure_channel(f'localhost:{port}')
    _LOGGER.info('Waiting for sub worker at port [%d]', port)
    await channel.channel_ready()
    stub = worker_service_pb2_grpc.WorkerServiceStub(channel)
    return _SubWorker(
        process=process,
        port=port,
        channel=channel,
        stub=stub,
    )


class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation."""

    def __init__(self):
        self._loop = asyncio.get_event_loop()
        # Set by QuitWorker; awaited by wait_for_quit.
        self._quit_event = asyncio.Event()

    async def _run_single_server(self, config, request_iterator, context):
        """Runs one benchmark server in this process until requests end.

        Writes an initial status once the server has started, then one
        status per incoming request. The server is stopped when the request
        stream ends or when an error interrupts the loop.
        """
        server, port = _create_server(config)
        await server.start()
        try:
            _LOGGER.info('Server started at port [%d]', port)

            start_time = time.monotonic()
            await context.write(
                _get_server_status(start_time, start_time, port))

            async for request in request_iterator:
                end_time = time.monotonic()
                status = _get_server_status(start_time, end_time, port)
                if request.mark.reset:
                    # The next window begins where this one ended.
                    start_time = end_time
                await context.write(status)
        finally:
            # Always release the listening port, even if the stream failed.
            await server.stop(None)

    async def RunServer(self, request_iterator, context):
        """Handles the RunServer streaming RPC from the driver.

        The first message carries the ServerConfig. With a single server
        process the server runs in this process; otherwise one sub worker per
        process is spawned and every request is relayed to all of them.

        NOTE(review): sub workers are only asked to quit on the normal path;
        an error while relaying leaves the child processes running.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ServerConfig: %s', config)

        if config.server_processes <= 0:
            _LOGGER.info('Using server_processes == [%d]', _NUM_CORES)
            config.server_processes = _NUM_CORES

        if config.port == 0:
            config.port = _pick_an_unused_port()
        _LOGGER.info('Port picked [%d]', config.port)

        if config.server_processes == 1:
            # If server_processes == 1, start the server in this process.
            await self._run_single_server(config, request_iterator, context)
        else:
            # If server_processes > 1, offload to other processes.
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.server_processes)])

            calls = [worker.stub.RunServer() for worker in sub_workers]

            # Each sub worker must run exactly one server itself.
            config_request.setup.server_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            await context.write(
                _get_server_status(
                    start_time,
                    start_time,
                    config.port,
                ))

            _LOGGER.info('Servers are ready to serve.')

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    await call.write(request)
                    # Reports from sub workers don't matter
                    await call.read()

                status = _get_server_status(
                    start_time,
                    end_time,
                    config.port,
                )
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for [%s] to quit...', worker)
                await worker.process.wait()

    async def _run_single_client(self, config, request_iterator, context):
        """Runs the benchmark clients in this process until requests end.

        Starts one client task per configured channel, writes an initial
        status, then one status per incoming request. All client tasks are
        cancelled when the request stream ends or an error interrupts it.
        """
        running_tasks = []
        qps_data = histogram.Histogram(config.histogram_params.resolution,
                                       config.histogram_params.max_possible)
        start_time = time.monotonic()

        try:
            # Create a client for each channel as asyncio.Task
            for i in range(config.client_channels):
                # Spread the channels over the server targets round-robin.
                server = config.server_targets[i % len(config.server_targets)]
                client = _create_client(server, config, qps_data)
                _LOGGER.info('Client created against server [%s]', server)
                running_tasks.append(self._loop.create_task(client.run()))

            end_time = time.monotonic()
            await context.write(
                _get_client_status(start_time, end_time, qps_data))

            # Respond to stat requests
            async for request in request_iterator:
                end_time = time.monotonic()
                status = _get_client_status(start_time, end_time, qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    start_time = time.monotonic()
                await context.write(status)
        finally:
            # Cleanup the clients, including any created before a failure.
            for task in running_tasks:
                task.cancel()

    async def RunClient(self, request_iterator, context):
        """Handles the RunClient streaming RPC from the driver.

        The first message carries the ClientConfig. With a single client
        process the clients run in this process; otherwise one sub worker per
        process is spawned, every request is relayed to all of them and their
        latency histograms are merged into the status sent back.

        NOTE(review): sub workers are only asked to quit on the normal path;
        an error while relaying leaves the child processes running.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ClientConfig: %s', config)

        if config.client_processes <= 0:
            _LOGGER.info('client_processes can\'t be [%d]',
                         config.client_processes)
            _LOGGER.info('Using client_processes == [%d]', _NUM_CORES)
            config.client_processes = _NUM_CORES

        if config.client_processes == 1:
            # If client_processes == 1, run the benchmark in this process.
            await self._run_single_client(config, request_iterator, context)
        else:
            # If client_processes > 1, offload the work to other processes.
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.client_processes)])

            calls = [worker.stub.RunClient() for worker in sub_workers]

            # Each sub worker must run the clients itself.
            config_request.setup.client_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            result = histogram.Histogram(config.histogram_params.resolution,
                                         config.histogram_params.max_possible)
            end_time = time.monotonic()
            await context.write(_get_client_status(start_time, end_time,
                                                   result))

            async for request in request_iterator:
                for call in calls:
                    _LOGGER.debug('Fetching status...')
                    await call.write(request)
                    sub_status = await call.read()
                    result.merge(sub_status.stats.latencies)
                    _LOGGER.debug('Update from sub worker count=[%d]',
                                  sub_status.stats.latencies.count)

                # Sample the end of the window only after every sub worker
                # has reported, so the elapsed time covers the merged data.
                end_time = time.monotonic()
                status = _get_client_status(start_time, end_time, result)
                if request.mark.reset:
                    result.reset()
                    start_time = time.monotonic()
                _LOGGER.debug('Reporting count=[%d]',
                              status.stats.latencies.count)
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for sub worker [%s] to quit...', worker)
                await worker.process.wait()
                _LOGGER.info('Sub worker [%s] quit', worker)

    @staticmethod
    async def CoreCount(unused_request, unused_context):
        """Reports the number of cores of this machine."""
        return control_pb2.CoreResponse(cores=_NUM_CORES)

    async def QuitWorker(self, unused_request, unused_context):
        """Signals wait_for_quit() that the worker was asked to quit."""
        _LOGGER.info('QuitWorker command received.')
        self._quit_event.set()
        return control_pb2.Void()

    async def wait_for_quit(self):
        """Blocks until a QuitWorker command has been received."""
        await self._quit_event.wait()