# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import collections
import logging
import multiprocessing
import os
import sys
import time
from typing import Tuple

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import control_pb2
from src.proto.grpc.testing import stats_pb2
from src.proto.grpc.testing import worker_service_pb2_grpc
from tests.qps import histogram
from tests.unit import resources
from tests.unit.framework.common import get_socket
from tests_aio.benchmark import benchmark_client
from tests_aio.benchmark import benchmark_servicer

_NUM_CORES = multiprocessing.cpu_count()
_WORKER_ENTRY_FILE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0], "worker.py"
)

_LOGGER = logging.getLogger(__name__)


class _SubWorker(
    collections.namedtuple("_SubWorker", ["process", "port", "channel", "stub"])
):
    """A data class that holds information about a child qps worker."""

    def _repr(self):
        return f"<_SubWorker pid={self.process.pid} port={self.port}>"

    def __repr__(self):
        return self._repr()

    def __str__(self):
        return self._repr()


def _get_server_status(
    start_time: float, end_time: float, port: int
) -> control_pb2.ServerStatus:
    """Creates ServerStatus proto message."""
    end_time = time.monotonic()
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate system time to compute QPS/core-second.
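    # NOTE: the `end_time` argument is superseded by the fresh monotonic
    # reading above, and wall-clock elapsed time currently stands in for both
    # user and system CPU time (see the TODO).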
    stats = stats_pb2.ServerStats(
        time_elapsed=elapsed_time,
        time_user=elapsed_time,
        time_system=elapsed_time,
    )
    return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES)


def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
    """Creates a server object according to the ServerConfig."""
    channel_args = tuple(
        (arg.name, arg.str_value)
        if arg.HasField("str_value")
        else (arg.name, int(arg.int_value))
        for arg in config.channel_args
    )

    server = aio.server(options=channel_args + (("grpc.so_reuseport", 1),))
    if config.server_type == control_pb2.ASYNC_SERVER:
        servicer = benchmark_servicer.BenchmarkServicer()
        benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
            servicer, server
        )
    elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
        resp_size = config.payload_config.bytebuf_params.resp_size
        servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size)
        method_implementations = {
            "StreamingCall": grpc.stream_stream_rpc_method_handler(
                servicer.StreamingCall
            ),
            "UnaryCall": grpc.unary_unary_rpc_method_handler(
                servicer.UnaryCall
            ),
        }
        handler = grpc.method_handlers_generic_handler(
            "grpc.testing.BenchmarkService", method_implementations
        )
        server.add_generic_rpc_handlers((handler,))
    else:
        raise NotImplementedError(
            "Unsupported server type {}".format(config.server_type)
        )

    if config.HasField("security_params"):  # Use SSL
        server_creds = grpc.ssl_server_credentials(
            ((resources.private_key(), resources.certificate_chain()),)
        )
        port = server.add_secure_port(
            "[::]:{}".format(config.port), server_creds
        )
    else:
        port = server.add_insecure_port("[::]:{}".format(config.port))

    return server, port


def _get_client_status(
    start_time: float, end_time: float, qps_data: histogram.Histogram
) -> control_pb2.ClientStatus:
    """Creates ClientStatus proto message."""
    latencies = qps_data.get_data()
    end_time = time.monotonic()
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate system time to compute QPS/core-second.
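    # `latencies` is a snapshot of the histogram shared by every channel task;
    # the same wall-clock placeholder noted in _get_server_status applies here.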
    stats = stats_pb2.ClientStats(
        latencies=latencies,
        time_elapsed=elapsed_time,
        time_user=elapsed_time,
        time_system=elapsed_time,
    )
    return control_pb2.ClientStatus(stats=stats)


def _create_client(
    server: str, config: control_pb2.ClientConfig, qps_data: histogram.Histogram
) -> benchmark_client.BenchmarkClient:
    """Creates a client object according to the ClientConfig."""
    if config.load_params.WhichOneof("load") != "closed_loop":
        raise NotImplementedError(
            f"Unsupported load parameter {config.load_params}"
        )

    if config.client_type == control_pb2.ASYNC_CLIENT:
        if config.rpc_type == control_pb2.UNARY:
            client_type = benchmark_client.UnaryAsyncBenchmarkClient
        elif config.rpc_type == control_pb2.STREAMING:
            client_type = benchmark_client.StreamingAsyncBenchmarkClient
        elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER:
            client_type = benchmark_client.ServerStreamingAsyncBenchmarkClient
        else:
            raise NotImplementedError(
                f"Unsupported rpc_type [{config.rpc_type}]"
            )
    else:
        raise NotImplementedError(
            f"Unsupported client type {config.client_type}"
        )

    return client_type(server, config, qps_data)


def _pick_an_unused_port() -> int:
    """Picks an unused TCP port."""
    _, port, sock = get_socket()
    sock.close()
    return port


async def _create_sub_worker() -> _SubWorker:
    """Creates a child qps worker as a subprocess."""
    port = _pick_an_unused_port()

    _LOGGER.info("Creating sub worker at port [%d]...", port)
    process = await asyncio.create_subprocess_exec(
        sys.executable, _WORKER_ENTRY_FILE, "--driver_port", str(port)
    )
    _LOGGER.info(
        "Created sub worker process for port [%d] at pid [%d]",
        port,
        process.pid,
    )
    channel = aio.insecure_channel(f"localhost:{port}")
    _LOGGER.info("Waiting for sub worker at port [%d]", port)
    await channel.channel_ready()
    stub = worker_service_pb2_grpc.WorkerServiceStub(channel)
    return _SubWorker(
        process=process,
        port=port,
        channel=channel,
        stub=stub,
    )


class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation."""

    def __init__(self):
        self._loop = asyncio.get_event_loop()
        self._quit_event = asyncio.Event()

    async def _run_single_server(self, config, request_iterator, context):
        server, port = _create_server(config)
        await server.start()
        _LOGGER.info("Server started at port [%d]", port)

        start_time = time.monotonic()
        await context.write(_get_server_status(start_time, start_time, port))

        async for request in request_iterator:
            end_time = time.monotonic()
            status = _get_server_status(start_time, end_time, port)
            if request.mark.reset:
                start_time = end_time
            await context.write(status)
        await server.stop(None)

    async def RunServer(self, request_iterator, context):
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info("Received ServerConfig: %s", config)

        if config.server_processes <= 0:
            _LOGGER.info("Using server_processes == [%d]", _NUM_CORES)
            config.server_processes = _NUM_CORES

        if config.port == 0:
            config.port = _pick_an_unused_port()
            _LOGGER.info("Port picked [%d]", config.port)

        if config.server_processes == 1:
            # If server_processes == 1, start the server in this process.
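            # The server shares this worker's event loop, so status requests
            # and benchmark traffic are handled by the same process.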
            await self._run_single_server(config, request_iterator, context)
        else:
            # If server_processes > 1, offload to other processes.
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.server_processes)]
            )

            calls = [worker.stub.RunServer() for worker in sub_workers]

            config_request.setup.server_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            await context.write(
                _get_server_status(
                    start_time,
                    start_time,
                    config.port,
                )
            )

            _LOGGER.info("Servers are ready to serve.")

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    await call.write(request)
                    # Reports from sub workers don't matter
                    await call.read()

                status = _get_server_status(
                    start_time,
                    end_time,
                    config.port,
                )
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info("Waiting for [%s] to quit...", worker)
                await worker.process.wait()

    async def _run_single_client(self, config, request_iterator, context):
        running_tasks = []
        qps_data = histogram.Histogram(
            config.histogram_params.resolution,
            config.histogram_params.max_possible,
        )
        start_time = time.monotonic()

        # Create a client for each channel as an asyncio.Task
        for i in range(config.client_channels):
            server = config.server_targets[i % len(config.server_targets)]
            client = _create_client(server, config, qps_data)
            _LOGGER.info("Client created against server [%s]", server)
            running_tasks.append(self._loop.create_task(client.run()))

        end_time = time.monotonic()
        await context.write(_get_client_status(start_time, end_time, qps_data))

        # Respond to stat requests
        async for request in request_iterator:
            end_time = time.monotonic()
            status = _get_client_status(start_time, end_time, qps_data)
            if request.mark.reset:
                qps_data.reset()
                start_time = time.monotonic()
            await context.write(status)

        # Clean up the clients
        for task in running_tasks:
            task.cancel()

    async def RunClient(self, request_iterator, context):
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info("Received ClientConfig: %s", config)

        if config.client_processes <= 0:
            _LOGGER.info(
                "client_processes can't be [%d]", config.client_processes
            )
            _LOGGER.info("Using client_processes == [%d]", _NUM_CORES)
            config.client_processes = _NUM_CORES

        if config.client_processes == 1:
            # If client_processes == 1, run the benchmark in this process.
            await self._run_single_client(config, request_iterator, context)
        else:
            # If client_processes > 1, offload the work to other processes.
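            # Spawn one child qps worker per requested process; each child
            # runs the single-process path above, and this parent merges their
            # latency histograms into one report.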
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.client_processes)]
            )

            calls = [worker.stub.RunClient() for worker in sub_workers]

            config_request.setup.client_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            result = histogram.Histogram(
                config.histogram_params.resolution,
                config.histogram_params.max_possible,
            )
            end_time = time.monotonic()
            await context.write(
                _get_client_status(start_time, end_time, result)
            )

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    _LOGGER.debug("Fetching status...")
                    await call.write(request)
                    sub_status = await call.read()
                    result.merge(sub_status.stats.latencies)
                    _LOGGER.debug(
                        "Update from sub worker count=[%d]",
                        sub_status.stats.latencies.count,
                    )

                status = _get_client_status(start_time, end_time, result)
                if request.mark.reset:
                    result.reset()
                    start_time = time.monotonic()
                _LOGGER.debug(
                    "Reporting count=[%d]", status.stats.latencies.count
                )
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info("Waiting for sub worker [%s] to quit...", worker)
                await worker.process.wait()
                _LOGGER.info("Sub worker [%s] quit", worker)

    @staticmethod
    async def CoreCount(unused_request, unused_context):
        return control_pb2.CoreResponse(cores=_NUM_CORES)

    async def QuitWorker(self, unused_request, unused_context):
        _LOGGER.info("QuitWorker command received.")
        self._quit_event.set()
        return control_pb2.Void()

    async def wait_for_quit(self):
        await self._quit_event.wait()
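
# A minimal sketch of how this servicer is typically mounted, assuming an
# entry point similar to tests_aio/benchmark/worker.py (the `run_worker` name
# and `driver_port` argument are illustrative, not the actual flags):
#
#     async def run_worker(driver_port: int) -> None:
#         server = aio.server()
#         servicer = WorkerServicer()
#         worker_service_pb2_grpc.add_WorkerServiceServicer_to_server(
#             servicer, server
#         )
#         server.add_insecure_port(f"[::]:{driver_port}")
#         await server.start()
#         # Block until the driver sends QuitWorker, then shut down.
#         await servicer.wait_for_quit()
#         await server.stop(None)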