# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import collections
import logging
import multiprocessing
import os
import sys
import time
from typing import Tuple

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
                                    stats_pb2, worker_service_pb2_grpc)
from tests.qps import histogram
from tests.unit import resources
from tests.unit.framework.common import get_socket
from tests_aio.benchmark import benchmark_client, benchmark_servicer

_NUM_CORES = multiprocessing.cpu_count()
# Script launched (with the current interpreter) for every child qps worker.
_WORKER_ENTRY_FILE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0], 'worker.py')

_LOGGER = logging.getLogger(__name__)


class _SubWorker(
        collections.namedtuple('_SubWorker',
                               ['process', 'port', 'channel', 'stub'])):
    """A data class that holds information about a child qps worker.

    Fields:
      process: The subprocess handle of the child worker.
      port: The driver port the child worker was started with.
      channel: The aio channel connected to ``localhost:<port>``.
      stub: A ``WorkerServiceStub`` bound to ``channel``.
    """

    def _repr(self):
        return f'<_SubWorker pid={self.process.pid} port={self.port}>'

    def __repr__(self):
        return self._repr()

    def __str__(self):
        return self._repr()


def _get_server_status(start_time: float, end_time: float,
                       port: int) -> control_pb2.ServerStatus:
    """Creates ServerStatus proto message.

    Args:
      start_time: ``time.monotonic()`` reading at the start of the interval.
      end_time: Accepted for interface compatibility only; it is superseded
        by a fresh ``time.monotonic()`` reading taken inside this function.
      port: The port to report in the status.

    Returns:
      A ServerStatus whose elapsed/user/system times are all set to the
      wall-clock time elapsed since ``start_time``.
    """
    # NOTE: the caller-supplied end_time is deliberately replaced here, so
    # the interval always ends at the moment the status is built.
    end_time = time.monotonic()
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES)


def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
    """Creates a server object according to the ServerConfig.

    Args:
      config: The ServerConfig describing server type, channel args, port and
        optional security parameters.

    Returns:
      A tuple of the (not yet started) aio server and the port it was bound
      to.

    Raises:
      NotImplementedError: If ``config.server_type`` is neither ASYNC_SERVER
        nor ASYNC_GENERIC_SERVER.
    """
    # Each channel arg carries either a string or an integer value.
    channel_args = tuple(
        (arg.name, arg.str_value) if arg.HasField('str_value') else
        (arg.name, int(arg.int_value)) for arg in config.channel_args)

    # so_reuseport is always enabled in addition to the requested args.
    server = aio.server(options=channel_args + (('grpc.so_reuseport', 1),))
    if config.server_type == control_pb2.ASYNC_SERVER:
        servicer = benchmark_servicer.BenchmarkServicer()
        benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
            servicer, server)
    elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
        resp_size = config.payload_config.bytebuf_params.resp_size
        servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size)
        method_implementations = {
            'StreamingCall':
                grpc.stream_stream_rpc_method_handler(servicer.StreamingCall),
            'UnaryCall':
                grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
        }
        handler = grpc.method_handlers_generic_handler(
            'grpc.testing.BenchmarkService', method_implementations)
        server.add_generic_rpc_handlers((handler,))
    else:
        raise NotImplementedError('Unsupported server type {}'.format(
            config.server_type))

    if config.HasField('security_params'):  # Use SSL
        server_creds = grpc.ssl_server_credentials(
            ((resources.private_key(), resources.certificate_chain()),))
        port = server.add_secure_port('[::]:{}'.format(config.port),
                                      server_creds)
    else:
        port = server.add_insecure_port('[::]:{}'.format(config.port))

    return server, port


def _get_client_status(start_time: float, end_time: float,
                       qps_data: histogram.Histogram
                      ) -> control_pb2.ClientStatus:
    """Creates ClientStatus proto message.

    Args:
      start_time: ``time.monotonic()`` reading at the start of the interval.
      end_time: Accepted for interface compatibility only; it is superseded
        by a fresh ``time.monotonic()`` reading taken after the latency data
        has been fetched.
      qps_data: The histogram whose data is reported as latencies.

    Returns:
      A ClientStatus whose elapsed/user/system times are all set to the
      wall-clock time elapsed since ``start_time``.
    """
    latencies = qps_data.get_data()
    # NOTE: the caller-supplied end_time is deliberately replaced here, so
    # the interval ends right after the latency data was read.
    end_time = time.monotonic()
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    stats = stats_pb2.ClientStats(latencies=latencies,
                                  time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ClientStatus(stats=stats)


def _create_client(server: str, config: control_pb2.ClientConfig,
                   qps_data: histogram.Histogram
                  ) -> benchmark_client.BenchmarkClient:
    """Creates a client object according to the ClientConfig.

    Args:
      server: The target address of the server to benchmark against.
      config: The ClientConfig selecting load, client and RPC types.
      qps_data: The histogram handed to the created client.

    Returns:
      A unary or streaming async benchmark client.

    Raises:
      NotImplementedError: If the load is not closed-loop, the client type is
        not ASYNC_CLIENT, or the RPC type is neither UNARY nor STREAMING.
    """
    if config.load_params.WhichOneof('load') != 'closed_loop':
        raise NotImplementedError(
            f'Unsupported load parameter {config.load_params}')

    if config.client_type == control_pb2.ASYNC_CLIENT:
        if config.rpc_type == control_pb2.UNARY:
            client_type = benchmark_client.UnaryAsyncBenchmarkClient
        elif config.rpc_type == control_pb2.STREAMING:
            client_type = benchmark_client.StreamingAsyncBenchmarkClient
        else:
            raise NotImplementedError(
                f'Unsupported rpc_type [{config.rpc_type}]')
    else:
        raise NotImplementedError(
            f'Unsupported client type {config.client_type}')

    return client_type(server, config, qps_data)


def _pick_an_unused_port() -> int:
    """Picks an unused TCP port.

    The socket obtained from ``get_socket`` is closed before returning, so
    only the port number is handed back to the caller.
    """
    _, port, sock = get_socket()
    sock.close()
    return port


async def _create_sub_worker() -> _SubWorker:
    """Creates a child qps worker as a subprocess.

    Launches ``worker.py`` with the current interpreter on a freshly picked
    port, opens an insecure channel to it and waits until the channel is
    ready.

    Returns:
      A _SubWorker bundling the process, port, channel and stub.
    """
    port = _pick_an_unused_port()

    _LOGGER.info('Creating sub worker at port [%d]...', port)
    process = await asyncio.create_subprocess_exec(sys.executable,
                                                   _WORKER_ENTRY_FILE,
                                                   '--driver_port', str(port))
    _LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port,
                 process.pid)
    channel = aio.insecure_channel(f'localhost:{port}')
    _LOGGER.info('Waiting for sub worker at port [%d]', port)
    await channel.channel_ready()
    stub = worker_service_pb2_grpc.WorkerServiceStub(channel)
    return _SubWorker(
        process=process,
        port=port,
        channel=channel,
        stub=stub,
    )


class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation."""

    def __init__(self):
        self._loop = asyncio.get_event_loop()
        # Set by QuitWorker; awaited by wait_for_quit.
        self._quit_event = asyncio.Event()

    async def _run_single_server(self, config, request_iterator, context):
        """Runs one benchmark server in this process.

        Writes an initial status once the server has started, then answers
        every incoming request with a fresh status. The server is stopped
        when the request stream ends or the handler exits with an error.
        """
        server, port = _create_server(config)
        await server.start()
        _LOGGER.info('Server started at port [%d]', port)

        try:
            start_time = time.monotonic()
            await context.write(
                _get_server_status(start_time, start_time, port))

            async for request in request_iterator:
                end_time = time.monotonic()
                status = _get_server_status(start_time, end_time, port)
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)
        finally:
            # Always release the listening server, even if the stream from
            # the driver fails or the RPC is cancelled.
            await server.stop(None)

    async def RunServer(self, request_iterator, context):
        """Handles the RunServer streaming RPC.

        The first message read from the stream carries the ServerConfig. A
        non-positive ``server_processes`` is replaced by the core count and a
        zero port by a freshly picked one. With a single process the server
        runs in this process; otherwise the work is forwarded to one sub
        worker per process.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ServerConfig: %s', config)

        if config.server_processes <= 0:
            _LOGGER.info('Using server_processes == [%d]', _NUM_CORES)
            config.server_processes = _NUM_CORES

        if config.port == 0:
            config.port = _pick_an_unused_port()
        _LOGGER.info('Port picked [%d]', config.port)

        if config.server_processes == 1:
            # If server_processes == 1, start the server in this process.
            await self._run_single_server(config, request_iterator, context)
        else:
            # If server_processes > 1, offload to other processes.
            sub_workers = await asyncio.gather(*(
                _create_sub_worker() for _ in range(config.server_processes)))

            calls = [worker.stub.RunServer() for worker in sub_workers]

            # Each sub worker runs exactly one server in its own process.
            config_request.setup.server_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            await context.write(
                _get_server_status(
                    start_time,
                    start_time,
                    config.port,
                ))

            _LOGGER.info('Servers are ready to serve.')

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    await call.write(request)
                    # Reports from sub workers don't matter
                    await call.read()

                status = _get_server_status(
                    start_time,
                    end_time,
                    config.port,
                )
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for [%s] to quit...', worker)
                await worker.process.wait()

    async def _run_single_client(self, config, request_iterator, context):
        """Runs the benchmark clients in this process.

        Starts one client task per configured channel, writes an initial
        status, then answers every incoming request with a fresh status
        (resetting the histogram and the start time when asked to). All
        client tasks are cancelled when the handler exits.
        """
        running_tasks = []
        qps_data = histogram.Histogram(config.histogram_params.resolution,
                                       config.histogram_params.max_possible)
        start_time = time.monotonic()

        try:
            # Create a client for each channel as asyncio.Task
            for i in range(config.client_channels):
                # Server targets are assigned to channels round-robin.
                server = config.server_targets[i % len(config.server_targets)]
                client = _create_client(server, config, qps_data)
                _LOGGER.info('Client created against server [%s]', server)
                running_tasks.append(self._loop.create_task(client.run()))

            end_time = time.monotonic()
            await context.write(
                _get_client_status(start_time, end_time, qps_data))

            # Respond to stat requests
            async for request in request_iterator:
                end_time = time.monotonic()
                status = _get_client_status(start_time, end_time, qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    start_time = time.monotonic()
                await context.write(status)
        finally:
            # Cleanup the clients, even if the stream from the driver fails
            # or the RPC is cancelled; otherwise they would keep running.
            for task in running_tasks:
                task.cancel()

    async def RunClient(self, request_iterator, context):
        """Handles the RunClient streaming RPC.

        The first message read from the stream carries the ClientConfig. A
        non-positive ``client_processes`` is replaced by the core count. With
        a single process the clients run in this process; otherwise the work
        is forwarded to one sub worker per process and their latency
        histograms are merged into the reported status.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ClientConfig: %s', config)

        if config.client_processes <= 0:
            _LOGGER.info('client_processes can\'t be [%d]',
                         config.client_processes)
            _LOGGER.info('Using client_processes == [%d]', _NUM_CORES)
            config.client_processes = _NUM_CORES

        if config.client_processes == 1:
            # If client_processes == 1, run the benchmark in this process.
            await self._run_single_client(config, request_iterator, context)
        else:
            # If client_processes > 1, offload the work to other processes.
            sub_workers = await asyncio.gather(*(
                _create_sub_worker() for _ in range(config.client_processes)))

            calls = [worker.stub.RunClient() for worker in sub_workers]

            # Each sub worker runs its clients in a single process.
            config_request.setup.client_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            result = histogram.Histogram(config.histogram_params.resolution,
                                         config.histogram_params.max_possible)
            end_time = time.monotonic()
            await context.write(_get_client_status(start_time, end_time,
                                                   result))

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    _LOGGER.debug('Fetching status...')
                    await call.write(request)
                    sub_status = await call.read()
                    result.merge(sub_status.stats.latencies)
                    _LOGGER.debug('Update from sub worker count=[%d]',
                                  sub_status.stats.latencies.count)

                status = _get_client_status(start_time, end_time, result)
                if request.mark.reset:
                    result.reset()
                    start_time = time.monotonic()
                _LOGGER.debug('Reporting count=[%d]',
                              status.stats.latencies.count)
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for sub worker [%s] to quit...', worker)
                await worker.process.wait()
                _LOGGER.info('Sub worker [%s] quit', worker)

    @staticmethod
    async def CoreCount(unused_request, unused_context):
        """Reports the number of CPU cores seen by this worker."""
        return control_pb2.CoreResponse(cores=_NUM_CORES)

    async def QuitWorker(self, unused_request, unused_context):
        """Signals the quit event so that ``wait_for_quit`` returns."""
        _LOGGER.info('QuitWorker command received.')
        self._quit_event.set()
        return control_pb2.Void()

    async def wait_for_quit(self):
        """Blocks until a QuitWorker request has been received."""
        await self._quit_event.wait()