• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import asyncio
16import collections
17import logging
18import multiprocessing
19import os
20import sys
21import time
22from typing import Tuple
23
24import grpc
25from grpc.experimental import aio
26
27from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
28                                    stats_pb2, worker_service_pb2_grpc)
29from tests.qps import histogram
30from tests.unit import resources
31from tests.unit.framework.common import get_socket
32from tests_aio.benchmark import benchmark_client, benchmark_servicer
33
# Core count of this machine; reported to the driver and used as the default
# number of server/client processes.
_NUM_CORES = multiprocessing.cpu_count()
# Script (located next to this module) launched for every child qps worker.
_WORKER_ENTRY_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                  'worker.py')

_LOGGER = logging.getLogger(__name__)
39
40
class _SubWorker(
        collections.namedtuple('_SubWorker',
                               ['process', 'port', 'channel', 'stub'])):
    """Record describing one child qps worker spawned by this worker.

    Fields:
      process: handle of the child process; its ``pid`` appears in the repr.
      port: port the child worker was told to serve on.
      channel: channel connected to the child.
      stub: WorkerService stub bound to ``channel``.
    """

    def _repr(self):
        pid = self.process.pid
        port = self.port
        return f'<_SubWorker pid={pid} port={port}>'

    def __repr__(self):
        return self._repr()

    # str() and repr() render identically.
    __str__ = __repr__
54
55
def _get_server_status(start_time: float, end_time: float,
                       port: int) -> control_pb2.ServerStatus:
    """Creates a ServerStatus proto message for the window [start_time, end_time].

    Args:
      start_time: time.monotonic() timestamp at which the current stats
        window began (server start or the last reset).
      end_time: time.monotonic() timestamp at which the window ends. The
        caller's value is used as-is (it is no longer overwritten with the
        current time), so the initial status really reports zero elapsed
        time, and a caller that resets the window to start at ``end_time``
        gets back-to-back, non-overlapping windows.
      port: port the benchmark server is listening on.

    Returns:
      A control_pb2.ServerStatus with the elapsed time, the port and the
      number of cores of this machine.
    """
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    # Until then, user and system time are both reported as wall-clock time.
    stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES)
66
67
def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
    """Builds (but does not start) an aio benchmark server from ``config``.

    The channel args from the config are passed as server options, with
    ``grpc.so_reuseport`` always appended.

    Args:
      config: the ServerConfig sent by the benchmark driver.

    Returns:
      A ``(server, port)`` tuple; ``port`` is the value returned by
      ``add_secure_port`` / ``add_insecure_port``.

    Raises:
      NotImplementedError: if ``config.server_type`` is neither
        ASYNC_SERVER nor ASYNC_GENERIC_SERVER.
    """
    options = []
    for arg in config.channel_args:
        if arg.HasField('str_value'):
            options.append((arg.name, arg.str_value))
        else:
            options.append((arg.name, int(arg.int_value)))
    options.append(('grpc.so_reuseport', 1))
    server = aio.server(options=tuple(options))

    server_type = config.server_type
    if server_type == control_pb2.ASYNC_SERVER:
        benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
            benchmark_servicer.BenchmarkServicer(), server)
    elif server_type == control_pb2.ASYNC_GENERIC_SERVER:
        generic_servicer = benchmark_servicer.GenericBenchmarkServicer(
            config.payload_config.bytebuf_params.resp_size)
        # No request/response (de)serializers are given to these handlers.
        handlers_by_method = {
            'StreamingCall':
                grpc.stream_stream_rpc_method_handler(
                    generic_servicer.StreamingCall),
            'UnaryCall':
                grpc.unary_unary_rpc_method_handler(generic_servicer.UnaryCall),
        }
        generic_handler = grpc.method_handlers_generic_handler(
            'grpc.testing.BenchmarkService', handlers_by_method)
        server.add_generic_rpc_handlers((generic_handler,))
    else:
        raise NotImplementedError(
            'Unsupported server type {}'.format(server_type))

    address = '[::]:{}'.format(config.port)
    if config.HasField('security_params'):
        # SSL: use the test key/certificate pair shipped with the test suite.
        credentials = grpc.ssl_server_credentials(
            ((resources.private_key(), resources.certificate_chain()),))
        bound_port = server.add_secure_port(address, credentials)
    else:
        bound_port = server.add_insecure_port(address)

    return server, bound_port
106
107
def _get_client_status(
        start_time: float, end_time: float,
        qps_data: histogram.Histogram) -> control_pb2.ClientStatus:
    """Creates a ClientStatus proto message for the window [start_time, end_time].

    Args:
      start_time: time.monotonic() timestamp at which the current stats
        window began (client start or the last reset).
      end_time: time.monotonic() timestamp at which the window ends. The
        caller's value is used as-is (it is no longer overwritten with the
        current time), so the reported elapsed time matches the instant the
        caller sampled.
      qps_data: histogram whose current data is reported as the latencies.

    Returns:
      A control_pb2.ClientStatus carrying the latencies and elapsed time.
    """
    latencies = qps_data.get_data()
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    # Until then, user and system time are both reported as wall-clock time.
    stats = stats_pb2.ClientStats(latencies=latencies,
                                  time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ClientStatus(stats=stats)
121
122
def _create_client(
        server: str, config: control_pb2.ClientConfig,
        qps_data: histogram.Histogram) -> benchmark_client.BenchmarkClient:
    """Instantiates the benchmark client matching ``config``.

    Only closed-loop load with ASYNC_CLIENT is supported. The concrete client
    class is chosen by ``config.rpc_type`` (UNARY, STREAMING or
    STREAMING_FROM_SERVER).

    Args:
      server: target address taken from ``config.server_targets``.
      config: the ClientConfig sent by the benchmark driver.
      qps_data: histogram handed to the client for its latency data.

    Raises:
      NotImplementedError: for any other load type, client type or rpc type.
    """
    if config.load_params.WhichOneof('load') != 'closed_loop':
        raise NotImplementedError(
            f'Unsupported load parameter {config.load_params}')

    if config.client_type != control_pb2.ASYNC_CLIENT:
        raise NotImplementedError(
            f'Unsupported client type {config.client_type}')

    rpc_type = config.rpc_type
    if rpc_type == control_pb2.UNARY:
        client_class = benchmark_client.UnaryAsyncBenchmarkClient
    elif rpc_type == control_pb2.STREAMING:
        client_class = benchmark_client.StreamingAsyncBenchmarkClient
    elif rpc_type == control_pb2.STREAMING_FROM_SERVER:
        client_class = benchmark_client.ServerStreamingAsyncBenchmarkClient
    else:
        raise NotImplementedError(f'Unsupported rpc_type [{rpc_type}]')

    return client_class(server, config, qps_data)
146
147
def _pick_an_unused_port() -> int:
    """Returns a TCP port number that was free when this function ran.

    A probe socket is obtained from ``get_socket``, its port is remembered,
    and the socket is closed so the caller can bind that port. Another
    process could grab the port in between, so this is best-effort.
    """
    _unused_host, free_port, probe_socket = get_socket()
    probe_socket.close()
    return free_port
153
154
async def _create_sub_worker() -> _SubWorker:
    """Spawns a child qps worker process and connects to it.

    The child runs ``worker.py --driver_port <port>`` with the current
    interpreter, on a port picked by ``_pick_an_unused_port``. This coroutine
    returns only after the channel to the child reports ready.

    Returns:
      A _SubWorker holding the child process, its port, the channel and a
      WorkerService stub bound to that channel.
    """
    driver_port = _pick_an_unused_port()
    _LOGGER.info('Creating sub worker at port [%d]...', driver_port)

    command = (sys.executable, _WORKER_ENTRY_FILE, '--driver_port',
               str(driver_port))
    child = await asyncio.create_subprocess_exec(*command)
    _LOGGER.info('Created sub worker process for port [%d] at pid [%d]',
                 driver_port, child.pid)

    channel = aio.insecure_channel(f'localhost:{driver_port}')
    _LOGGER.info('Waiting for sub worker at port [%d]', driver_port)
    # NOTE(review): no timeout; if the child never comes up this presumably
    # waits indefinitely.
    await channel.channel_ready()

    return _SubWorker(
        process=child,
        port=driver_port,
        channel=channel,
        stub=worker_service_pb2_grpc.WorkerServiceStub(channel),
    )
175
176
class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation.

    Serves the WorkerService RPCs issued by the QPS benchmark driver.
    RunServer and RunClient either run the benchmark in this process or, when
    more than one process is requested, spawn child workers and relay the
    driver's marks to them.
    """

    def __init__(self):
        # Event loop used to schedule the benchmark client tasks.
        self._loop = asyncio.get_event_loop()
        # Set by QuitWorker; wait_for_quit() blocks until it is set.
        self._quit_event = asyncio.Event()

    async def _run_single_server(self, config, request_iterator, context):
        """Runs one benchmark server in this process.

        Writes an initial ServerStatus once the server has started, then one
        ServerStatus per request read from ``request_iterator``. When
        ``request.mark.reset`` is set, the stats window restarts at the time
        that status was taken. The server is stopped when the request stream
        ends, including when it ends with an error or cancellation.
        """
        server, port = _create_server(config)
        await server.start()
        _LOGGER.info('Server started at port [%d]', port)

        try:
            start_time = time.monotonic()
            await context.write(
                _get_server_status(start_time, start_time, port))

            async for request in request_iterator:
                end_time = time.monotonic()
                status = _get_server_status(start_time, end_time, port)
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)
        finally:
            await server.stop(None)

    async def RunServer(self, request_iterator, context):
        """Handles the driver's RunServer stream.

        The first message carries the ServerConfig in ``setup``. A
        non-positive ``server_processes`` is replaced by the core count, and
        a zero ``port`` by a freshly picked one. With one process the server
        runs here. Otherwise that many sub workers are spawned, each sent the
        config with ``server_processes = 1``, and every subsequent request
        from the driver is forwarded to all of them.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ServerConfig: %s', config)

        if config.server_processes <= 0:
            _LOGGER.info('Using server_processes == [%d]', _NUM_CORES)
            config.server_processes = _NUM_CORES

        if config.port == 0:
            config.port = _pick_an_unused_port()
        _LOGGER.info('Port picked [%d]', config.port)

        if config.server_processes == 1:
            # If server_processes == 1, start the server in this process.
            await self._run_single_server(config, request_iterator, context)
        else:
            # If server_processes > 1, offload to other processes.
            # NOTE(review): sub workers are only shut down on the normal path
            # below; if the request stream fails they are left running.
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.server_processes)])

            calls = [worker.stub.RunServer() for worker in sub_workers]

            config_request.setup.server_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            await context.write(
                _get_server_status(
                    start_time,
                    start_time,
                    config.port,
                ))

            _LOGGER.info('Servers are ready to serve.')

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    await call.write(request)
                    # Reports from sub workers don't matter; only this
                    # process's status is sent back to the driver.
                    await call.read()

                status = _get_server_status(
                    start_time,
                    end_time,
                    config.port,
                )
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for [%s] to quit...', worker)
                await worker.process.wait()

    async def _run_single_client(self, config, request_iterator, context):
        """Runs the benchmark clients for this process.

        Starts ``config.client_channels`` clients as tasks, assigned
        round-robin over ``config.server_targets``, all sharing one latency
        histogram. Writes an initial ClientStatus, then one ClientStatus per
        request. On ``request.mark.reset`` the histogram and the stats window
        are reset after the status is taken. The client tasks are cancelled
        and awaited when the request stream ends, including on error or
        cancellation, so none is left pending.
        """
        running_tasks = []
        qps_data = histogram.Histogram(config.histogram_params.resolution,
                                       config.histogram_params.max_possible)
        start_time = time.monotonic()

        try:
            # Create a client for each channel as asyncio.Task
            for i in range(config.client_channels):
                server = config.server_targets[i % len(config.server_targets)]
                client = _create_client(server, config, qps_data)
                _LOGGER.info('Client created against server [%s]', server)
                running_tasks.append(self._loop.create_task(client.run()))

            end_time = time.monotonic()
            await context.write(
                _get_client_status(start_time, end_time, qps_data))

            # Respond to stat requests
            async for request in request_iterator:
                end_time = time.monotonic()
                status = _get_client_status(start_time, end_time, qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    start_time = time.monotonic()
                await context.write(status)
        finally:
            # Cleanup the clients
            for task in running_tasks:
                task.cancel()
            # Wait for the cancellations to take effect and retrieve any task
            # failure, instead of leaving it to be reported at garbage
            # collection.
            outcomes = await asyncio.gather(*running_tasks,
                                            return_exceptions=True)
            for outcome in outcomes:
                if (isinstance(outcome, Exception) and
                        not isinstance(outcome, asyncio.CancelledError)):
                    _LOGGER.warning('Benchmark client task failed: %r',
                                    outcome)

    async def RunClient(self, request_iterator, context):
        """Handles the driver's RunClient stream.

        The first message carries the ClientConfig in ``setup``. A
        non-positive ``client_processes`` is replaced by the core count. With
        one process the clients run here. Otherwise that many sub workers are
        spawned, each sent the config with ``client_processes = 1``. Every
        request from the driver is forwarded to all of them, and their
        latency reports are merged into a single ClientStatus.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ClientConfig: %s', config)

        if config.client_processes <= 0:
            _LOGGER.info('client_processes can\'t be [%d]',
                         config.client_processes)
            _LOGGER.info('Using client_processes == [%d]', _NUM_CORES)
            config.client_processes = _NUM_CORES

        if config.client_processes == 1:
            # If client_processes == 1, run the benchmark in this process.
            await self._run_single_client(config, request_iterator, context)
        else:
            # If client_processes > 1, offload the work to other processes.
            # NOTE(review): sub workers are only shut down on the normal path
            # below; if the request stream fails they are left running.
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.client_processes)])

            calls = [worker.stub.RunClient() for worker in sub_workers]

            config_request.setup.client_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            result = histogram.Histogram(config.histogram_params.resolution,
                                         config.histogram_params.max_possible)
            end_time = time.monotonic()
            await context.write(_get_client_status(start_time, end_time,
                                                   result))

            async for request in request_iterator:
                end_time = time.monotonic()

                # Each sub worker reports everything it has collected since
                # its own last reset (see _run_single_client). Rebuild the
                # aggregate from this round's reports only: merging on top of
                # earlier rounds would count samples more than once whenever
                # a mark does not request a reset.
                result.reset()
                for call in calls:
                    _LOGGER.debug('Fetching status...')
                    await call.write(request)
                    sub_status = await call.read()
                    result.merge(sub_status.stats.latencies)
                    _LOGGER.debug('Update from sub worker count=[%d]',
                                  sub_status.stats.latencies.count)

                status = _get_client_status(start_time, end_time, result)
                if request.mark.reset:
                    start_time = time.monotonic()
                _LOGGER.debug('Reporting count=[%d]',
                              status.stats.latencies.count)
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for sub worker [%s] to quit...', worker)
                await worker.process.wait()
                _LOGGER.info('Sub worker [%s] quit', worker)

    @staticmethod
    async def CoreCount(unused_request, unused_context):
        """Returns the number of cores of this machine."""
        return control_pb2.CoreResponse(cores=_NUM_CORES)

    async def QuitWorker(self, unused_request, unused_context):
        """Signals wait_for_quit() to return; replies with an empty Void."""
        _LOGGER.info('QuitWorker command received.')
        self._quit_event.set()
        return control_pb2.Void()

    async def wait_for_quit(self):
        """Blocks until QuitWorker has been called."""
        await self._quit_event.wait()
370