# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import collections
import logging
import multiprocessing
import os
import sys
import time
from typing import Tuple

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
                                    stats_pb2, worker_service_pb2_grpc)
from tests.qps import histogram
from tests.unit import resources
from tests.unit.framework.common import get_socket
from tests_aio.benchmark import benchmark_client, benchmark_servicer

_NUM_CORES = multiprocessing.cpu_count()
_WORKER_ENTRY_FILE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0], 'worker.py')
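# Sub workers are spawned by executing _WORKER_ENTRY_FILE with the current
# interpreter (see _create_sub_worker below), roughly:
#   python worker.py --driver_port <port>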

_LOGGER = logging.getLogger(__name__)


class _SubWorker(
        collections.namedtuple('_SubWorker',
                               ['process', 'port', 'channel', 'stub'])):
    """A data class that holds information about a child qps worker."""

    def _repr(self):
        return f'<_SubWorker pid={self.process.pid} port={self.port}>'

    def __repr__(self):
        return self._repr()

    def __str__(self):
        return self._repr()


def _get_server_status(start_time: float, end_time: float,
                       port: int) -> control_pb2.ServerStatus:
    """Creates ServerStatus proto message."""
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES)


def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
    """Creates a server object according to the ServerConfig."""
    channel_args = tuple(
        (arg.name, arg.str_value) if arg.HasField('str_value') else
        (arg.name, int(arg.int_value)) for arg in config.channel_args)

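    # grpc.so_reuseport is forced on so that multiple worker processes can
    # bind the same port when server_processes > 1.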
    server = aio.server(options=channel_args + (('grpc.so_reuseport', 1),))
    if config.server_type == control_pb2.ASYNC_SERVER:
        servicer = benchmark_servicer.BenchmarkServicer()
        benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
            servicer, server)
    elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
        resp_size = config.payload_config.bytebuf_params.resp_size
        servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size)
        method_implementations = {
            'StreamingCall':
                grpc.stream_stream_rpc_method_handler(servicer.StreamingCall),
            'UnaryCall':
                grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
        }
        handler = grpc.method_handlers_generic_handler(
            'grpc.testing.BenchmarkService', method_implementations)
        server.add_generic_rpc_handlers((handler,))
    else:
        raise NotImplementedError('Unsupported server type {}'.format(
            config.server_type))

    if config.HasField('security_params'):  # Use SSL
        server_creds = grpc.ssl_server_credentials(
            ((resources.private_key(), resources.certificate_chain()),))
        port = server.add_secure_port('[::]:{}'.format(config.port),
                                      server_creds)
    else:
        port = server.add_insecure_port('[::]:{}'.format(config.port))

    return server, port


def _get_client_status(start_time: float, end_time: float,
                       qps_data: histogram.Histogram
                      ) -> control_pb2.ClientStatus:
    """Creates ClientStatus proto message."""
    latencies = qps_data.get_data()
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    stats = stats_pb2.ClientStats(latencies=latencies,
                                  time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ClientStatus(stats=stats)


def _create_client(server: str, config: control_pb2.ClientConfig,
                   qps_data: histogram.Histogram
                  ) -> benchmark_client.BenchmarkClient:
    """Creates a client object according to the ClientConfig."""
    if config.load_params.WhichOneof('load') != 'closed_loop':
        raise NotImplementedError(
            f'Unsupported load parameter {config.load_params}')

    if config.client_type == control_pb2.ASYNC_CLIENT:
        if config.rpc_type == control_pb2.UNARY:
            client_type = benchmark_client.UnaryAsyncBenchmarkClient
        elif config.rpc_type == control_pb2.STREAMING:
            client_type = benchmark_client.StreamingAsyncBenchmarkClient
        else:
            raise NotImplementedError(
                f'Unsupported rpc_type [{config.rpc_type}]')
    else:
        raise NotImplementedError(
            f'Unsupported client type {config.client_type}')

    return client_type(server, config, qps_data)


def _pick_an_unused_port() -> int:
    """Picks an unused TCP port."""
    _, port, sock = get_socket()
    sock.close()
    return port


async def _create_sub_worker() -> _SubWorker:
    """Creates a child qps worker as a subprocess."""
    port = _pick_an_unused_port()

    _LOGGER.info('Creating sub worker at port [%d]...', port)
    process = await asyncio.create_subprocess_exec(sys.executable,
                                                   _WORKER_ENTRY_FILE,
                                                   '--driver_port', str(port))
    _LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port,
                 process.pid)
    channel = aio.insecure_channel(f'localhost:{port}')
    _LOGGER.info('Waiting for sub worker at port [%d]', port)
    await channel.channel_ready()
    stub = worker_service_pb2_grpc.WorkerServiceStub(channel)
    return _SubWorker(
        process=process,
        port=port,
        channel=channel,
        stub=stub,
    )


class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation."""

    def __init__(self):
        self._loop = asyncio.get_event_loop()
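        # Set by QuitWorker and awaited by wait_for_quit so the worker entry
        # point knows when to shut down.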
        self._quit_event = asyncio.Event()

    async def _run_single_server(self, config, request_iterator, context):
        server, port = _create_server(config)
        await server.start()
        _LOGGER.info('Server started at port [%d]', port)

        start_time = time.monotonic()
        await context.write(_get_server_status(start_time, start_time, port))

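        # Report a fresh status for every Mark request from the driver;
        # mark.reset restarts the measurement window.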
        async for request in request_iterator:
            end_time = time.monotonic()
            status = _get_server_status(start_time, end_time, port)
            if request.mark.reset:
                start_time = end_time
            await context.write(status)
        await server.stop(None)

    async def RunServer(self, request_iterator, context):
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ServerConfig: %s', config)

        if config.server_processes <= 0:
            _LOGGER.info('Using server_processes == [%d]', _NUM_CORES)
            config.server_processes = _NUM_CORES

        if config.port == 0:
            config.port = _pick_an_unused_port()
        _LOGGER.info('Port picked [%d]', config.port)

        if config.server_processes == 1:
            # If server_processes == 1, start the server in this process.
            await self._run_single_server(config, request_iterator, context)
        else:
            # If server_processes > 1, offload to other processes.
            sub_workers = await asyncio.gather(*(
                _create_sub_worker() for _ in range(config.server_processes)))

            calls = [worker.stub.RunServer() for worker in sub_workers]

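            # Each sub worker must itself run exactly one in-process server,
            # so reset the process count before forwarding the config.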
            config_request.setup.server_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            await context.write(
                _get_server_status(
                    start_time,
                    start_time,
                    config.port,
                ))

            _LOGGER.info('Servers are ready to serve.')

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    await call.write(request)
                    # Reports from sub workers don't matter here
                    await call.read()

                status = _get_server_status(
                    start_time,
                    end_time,
                    config.port,
                )
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for [%s] to quit...', worker)
                await worker.process.wait()

    async def _run_single_client(self, config, request_iterator, context):
        running_tasks = []
        qps_data = histogram.Histogram(config.histogram_params.resolution,
                                       config.histogram_params.max_possible)
        start_time = time.monotonic()

        # Create a client for each channel as an asyncio.Task
        for i in range(config.client_channels):
            server = config.server_targets[i % len(config.server_targets)]
            client = _create_client(server, config, qps_data)
            _LOGGER.info('Client created against server [%s]', server)
            running_tasks.append(self._loop.create_task(client.run()))

        end_time = time.monotonic()
        await context.write(_get_client_status(start_time, end_time, qps_data))

        # Respond to stat requests
        async for request in request_iterator:
            end_time = time.monotonic()
            status = _get_client_status(start_time, end_time, qps_data)
            if request.mark.reset:
                qps_data.reset()
                start_time = time.monotonic()
            await context.write(status)

        # Clean up the clients
        for task in running_tasks:
            task.cancel()

    async def RunClient(self, request_iterator, context):
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info('Received ClientConfig: %s', config)

        if config.client_processes <= 0:
            _LOGGER.info('client_processes can\'t be [%d]',
                         config.client_processes)
            _LOGGER.info('Using client_processes == [%d]', _NUM_CORES)
            config.client_processes = _NUM_CORES

        if config.client_processes == 1:
            # If client_processes == 1, run the benchmark in this process.
            await self._run_single_client(config, request_iterator, context)
        else:
            # If client_processes > 1, offload the work to other processes.
            sub_workers = await asyncio.gather(*(
                _create_sub_worker() for _ in range(config.client_processes)))

            calls = [worker.stub.RunClient() for worker in sub_workers]

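            # Each sub worker runs a single in-process client; their latency
            # histograms are merged into one aggregate below.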
            config_request.setup.client_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
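            # `result` accumulates the latency data merged from every sub
            # worker.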
            result = histogram.Histogram(config.histogram_params.resolution,
                                         config.histogram_params.max_possible)
            end_time = time.monotonic()
            await context.write(_get_client_status(start_time, end_time,
                                                   result))

            async for request in request_iterator:
                end_time = time.monotonic()

                for call in calls:
                    _LOGGER.debug('Fetching status...')
                    await call.write(request)
                    sub_status = await call.read()
                    result.merge(sub_status.stats.latencies)
                    _LOGGER.debug('Update from sub worker count=[%d]',
                                  sub_status.stats.latencies.count)

                status = _get_client_status(start_time, end_time, result)
                if request.mark.reset:
                    result.reset()
                    start_time = time.monotonic()
                _LOGGER.debug('Reporting count=[%d]',
                              status.stats.latencies.count)
                await context.write(status)

            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info('Waiting for sub worker [%s] to quit...', worker)
                await worker.process.wait()
                _LOGGER.info('Sub worker [%s] quit', worker)

    @staticmethod
    async def CoreCount(unused_request, unused_context):
        return control_pb2.CoreResponse(cores=_NUM_CORES)

    async def QuitWorker(self, unused_request, unused_context):
        _LOGGER.info('QuitWorker command received.')
        self._quit_event.set()
        return control_pb2.Void()

    async def wait_for_quit(self):
        await self._quit_event.wait()
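# Illustrative wiring only; the real entry point lives in worker.py, which is
# assumed to do roughly the following with this servicer:
#
#     async def run_worker(driver_port):
#         server = aio.server()
#         servicer = WorkerServicer()
#         worker_service_pb2_grpc.add_WorkerServiceServicer_to_server(
#             servicer, server)
#         server.add_insecure_port(f'[::]:{driver_port}')
#         await server.start()
#         await servicer.wait_for_quit()
#         await server.stop(None)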