1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import collections
17import logging
18import multiprocessing
19import os
20import sys
21import time
22from typing import Tuple
23
24import grpc
25from grpc.experimental import aio
26
27from src.proto.grpc.testing import benchmark_service_pb2_grpc
28from src.proto.grpc.testing import control_pb2
29from src.proto.grpc.testing import stats_pb2
30from src.proto.grpc.testing import worker_service_pb2_grpc
31from tests.qps import histogram
32from tests.unit import resources
33from tests.unit.framework.common import get_socket
34from tests_aio.benchmark import benchmark_client
35from tests_aio.benchmark import benchmark_servicer
36
# Number of CPU cores on this machine; used as the default number of
# server/client worker processes when the driver does not specify one.
_NUM_CORES = multiprocessing.cpu_count()
# Entry script for spawning child qps workers, resolved next to this file.
_WORKER_ENTRY_FILE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0], "worker.py"
)

_LOGGER = logging.getLogger(__name__)
43
44
45class _SubWorker(
46    collections.namedtuple("_SubWorker", ["process", "port", "channel", "stub"])
47):
48    """A data class that holds information about a child qps worker."""
49
50    def _repr(self):
51        return f"<_SubWorker pid={self.process.pid} port={self.port}>"
52
53    def __repr__(self):
54        return self._repr()
55
56    def __str__(self):
57        return self._repr()
58
59
def _get_server_status(
    start_time: float, end_time: float, port: int
) -> control_pb2.ServerStatus:
    """Creates ServerStatus proto message.

    Args:
        start_time: Monotonic timestamp at the start of the stats window.
        end_time: Monotonic timestamp at the end of the stats window.
        port: The port the benchmark server is serving on.

    Returns:
        A control_pb2.ServerStatus carrying elapsed-time stats and core count.
    """
    # Fix: the previous version overwrote the caller-supplied end_time with
    # time.monotonic(), silently ignoring the timestamps callers captured.
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    # time_user/time_system are approximated with wall time for now.
    stats = stats_pb2.ServerStats(
        time_elapsed=elapsed_time,
        time_user=elapsed_time,
        time_system=elapsed_time,
    )
    return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES)
73
74
def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
    """Creates a server object according to the ServerConfig.

    Args:
        config: The ServerConfig received from the driver.

    Returns:
        A tuple of the (not yet started) aio.Server and its bound port.

    Raises:
        NotImplementedError: If the requested server type is unsupported.
    """
    # Translate proto channel args into (name, value) option tuples.
    options = []
    for arg in config.channel_args:
        if arg.HasField("str_value"):
            options.append((arg.name, arg.str_value))
        else:
            options.append((arg.name, int(arg.int_value)))

    server = aio.server(options=tuple(options) + (("grpc.so_reuseport", 1),))

    if config.server_type == control_pb2.ASYNC_SERVER:
        # Regular protobuf-based benchmark service.
        benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
            benchmark_servicer.BenchmarkServicer(), server
        )
    elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
        # Generic (bytes-level) service with a fixed response size.
        resp_size = config.payload_config.bytebuf_params.resp_size
        generic_servicer = benchmark_servicer.GenericBenchmarkServicer(
            resp_size
        )
        handler = grpc.method_handlers_generic_handler(
            "grpc.testing.BenchmarkService",
            {
                "StreamingCall": grpc.stream_stream_rpc_method_handler(
                    generic_servicer.StreamingCall
                ),
                "UnaryCall": grpc.unary_unary_rpc_method_handler(
                    generic_servicer.UnaryCall
                ),
            },
        )
        server.add_generic_rpc_handlers((handler,))
    else:
        raise NotImplementedError(
            "Unsupported server type {}".format(config.server_type)
        )

    if config.HasField("security_params"):  # Use SSL
        server_creds = grpc.ssl_server_credentials(
            ((resources.private_key(), resources.certificate_chain()),)
        )
        port = server.add_secure_port(
            "[::]:{}".format(config.port), server_creds
        )
    else:
        port = server.add_insecure_port("[::]:{}".format(config.port))

    return server, port
121
122
def _get_client_status(
    start_time: float, end_time: float, qps_data: histogram.Histogram
) -> control_pb2.ClientStatus:
    """Creates ClientStatus proto message.

    Args:
        start_time: Monotonic timestamp at the start of the stats window.
        end_time: Monotonic timestamp at the end of the stats window.
        qps_data: Histogram of observed request latencies.

    Returns:
        A control_pb2.ClientStatus carrying latency and elapsed-time stats.
    """
    latencies = qps_data.get_data()
    # Fix: the previous version overwrote the caller-supplied end_time with
    # time.monotonic(), silently ignoring the timestamps callers captured.
    elapsed_time = end_time - start_time
    # TODO(lidiz) Collect accurate time system to compute QPS/core-second.
    # time_user/time_system are approximated with wall time for now.
    stats = stats_pb2.ClientStats(
        latencies=latencies,
        time_elapsed=elapsed_time,
        time_user=elapsed_time,
        time_system=elapsed_time,
    )
    return control_pb2.ClientStatus(stats=stats)
138
139
def _create_client(
    server: str, config: control_pb2.ClientConfig, qps_data: histogram.Histogram
) -> benchmark_client.BenchmarkClient:
    """Creates a client object according to the ClientConfig.

    Args:
        server: Target address of the benchmark server.
        config: The ClientConfig received from the driver.
        qps_data: Histogram that the client records latencies into.

    Raises:
        NotImplementedError: On unsupported load params, client type, or
            rpc type.
    """
    # Only closed-loop load is implemented.
    if config.load_params.WhichOneof("load") != "closed_loop":
        raise NotImplementedError(
            f"Unsupported load parameter {config.load_params}"
        )

    # Only the asyncio client is implemented.
    if config.client_type != control_pb2.ASYNC_CLIENT:
        raise NotImplementedError(
            f"Unsupported client type {config.client_type}"
        )

    # Dispatch rpc_type to the matching benchmark client class.
    rpc_type_to_client = {
        control_pb2.UNARY: benchmark_client.UnaryAsyncBenchmarkClient,
        control_pb2.STREAMING: benchmark_client.StreamingAsyncBenchmarkClient,
        control_pb2.STREAMING_FROM_SERVER: (
            benchmark_client.ServerStreamingAsyncBenchmarkClient
        ),
    }
    client_type = rpc_type_to_client.get(config.rpc_type)
    if client_type is None:
        raise NotImplementedError(
            f"Unsupported rpc_type [{config.rpc_type}]"
        )

    return client_type(server, config, qps_data)
166
167
def _pick_an_unused_port() -> int:
    """Picks an unused TCP port.

    Note: the port is released before returning, so a race with another
    process grabbing it is possible (mitigated by SO_REUSEPORT elsewhere).
    """
    _, free_port, probe_socket = get_socket()
    probe_socket.close()
    return free_port
173
174
async def _create_sub_worker() -> _SubWorker:
    """Spawns a child qps worker subprocess and connects to it.

    Returns:
        A _SubWorker holding the process handle, port, channel, and stub.
    """
    chosen_port = _pick_an_unused_port()

    _LOGGER.info("Creating sub worker at port [%d]...", chosen_port)
    worker_process = await asyncio.create_subprocess_exec(
        sys.executable, _WORKER_ENTRY_FILE, "--driver_port", str(chosen_port)
    )
    _LOGGER.info(
        "Created sub worker process for port [%d] at pid [%d]",
        chosen_port,
        worker_process.pid,
    )
    worker_channel = aio.insecure_channel(f"localhost:{chosen_port}")
    _LOGGER.info("Waiting for sub worker at port [%d]", chosen_port)
    # Block until the child worker's server is accepting RPCs.
    await worker_channel.channel_ready()
    worker_stub = worker_service_pb2_grpc.WorkerServiceStub(worker_channel)
    return _SubWorker(
        process=worker_process,
        port=chosen_port,
        channel=worker_channel,
        stub=worker_stub,
    )
198
199
class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation.

    Serves the qps driver's WorkerService: runs benchmark servers/clients
    either in-process (single process) or fanned out across child worker
    subprocesses, and streams periodic status back to the driver.
    """

    def __init__(self):
        # Loop used to schedule per-channel client benchmark tasks.
        self._loop = asyncio.get_event_loop()
        # Set by QuitWorker; wait_for_quit() blocks on it.
        self._quit_event = asyncio.Event()

    async def _run_single_server(self, config, request_iterator, context):
        """Runs one benchmark server in this process.

        Writes an initial status (the readiness signal), then one status per
        Mark request until the driver half-closes, then stops the server.
        """
        server, port = _create_server(config)
        await server.start()
        _LOGGER.info("Server started at port [%d]", port)

        start_time = time.monotonic()
        # First status doubles as the "server is ready" signal.
        await context.write(_get_server_status(start_time, start_time, port))

        async for request in request_iterator:
            end_time = time.monotonic()
            status = _get_server_status(start_time, end_time, port)
            if request.mark.reset:
                # Driver requested a stats reset: restart the window.
                start_time = end_time
            await context.write(status)
        await server.stop(None)

    async def RunServer(self, request_iterator, context):
        """Handles the driver's RunServer bidi stream.

        Reads the ServerConfig first; then either serves in-process
        (server_processes == 1) or spawns child workers and proxies
        configuration/Mark messages to each of them.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info("Received ServerConfig: %s", config)

        if config.server_processes <= 0:
            # Unspecified process count: default to one per CPU core.
            _LOGGER.info("Using server_processes == [%d]", _NUM_CORES)
            config.server_processes = _NUM_CORES

        if config.port == 0:
            # Driver left port selection to us.
            config.port = _pick_an_unused_port()
        _LOGGER.info("Port picked [%d]", config.port)

        if config.server_processes == 1:
            # If server_processes == 1, start the server in this process.
            await self._run_single_server(config, request_iterator, context)
        else:
            # If server_processes > 1, offload to other processes.
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.server_processes)]
            )

            calls = [worker.stub.RunServer() for worker in sub_workers]

            # Children must not fan out further.
            config_request.setup.server_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            # Report readiness to the driver once all children are up.
            await context.write(
                _get_server_status(
                    start_time,
                    start_time,
                    config.port,
                )
            )

            _LOGGER.info("Servers are ready to serve.")

            async for request in request_iterator:
                end_time = time.monotonic()

                # Forward each Mark to every child, draining its reply.
                for call in calls:
                    await call.write(request)
                    # Reports from sub workers don't matter here
                    await call.read()

                status = _get_server_status(
                    start_time,
                    end_time,
                    config.port,
                )
                if request.mark.reset:
                    start_time = end_time
                await context.write(status)

            # Driver half-closed: shut down children in order.
            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info("Waiting for [%s] to quit...", worker)
                await worker.process.wait()

    async def _run_single_client(self, config, request_iterator, context):
        """Runs the client-side benchmark in this process.

        Starts one client task per channel, streams a status per Mark
        request, and cancels the tasks when the driver half-closes.
        """
        running_tasks = []
        qps_data = histogram.Histogram(
            config.histogram_params.resolution,
            config.histogram_params.max_possible,
        )
        start_time = time.monotonic()

        # Create a client for each channel as asyncio.Task
        for i in range(config.client_channels):
            # Spread channels round-robin over the configured servers.
            server = config.server_targets[i % len(config.server_targets)]
            client = _create_client(server, config, qps_data)
            _LOGGER.info("Client created against server [%s]", server)
            running_tasks.append(self._loop.create_task(client.run()))

        end_time = time.monotonic()
        # Initial status doubles as the "clients are running" signal.
        await context.write(_get_client_status(start_time, end_time, qps_data))

        # Respond to stat requests
        async for request in request_iterator:
            end_time = time.monotonic()
            status = _get_client_status(start_time, end_time, qps_data)
            if request.mark.reset:
                # Clear latency data and restart the stats window.
                qps_data.reset()
                start_time = time.monotonic()
            await context.write(status)

        # Cleanup the clients
        for task in running_tasks:
            task.cancel()

    async def RunClient(self, request_iterator, context):
        """Handles the driver's RunClient bidi stream.

        Reads the ClientConfig first; then either benchmarks in-process
        (client_processes == 1) or spawns child workers, merging their
        latency histograms into a single report for the driver.
        """
        config_request = await context.read()
        config = config_request.setup
        _LOGGER.info("Received ClientConfig: %s", config)

        if config.client_processes <= 0:
            # Unspecified process count: default to one per CPU core.
            _LOGGER.info(
                "client_processes can't be [%d]", config.client_processes
            )
            _LOGGER.info("Using client_processes == [%d]", _NUM_CORES)
            config.client_processes = _NUM_CORES

        if config.client_processes == 1:
            # If client_processes == 1, run the benchmark in this process.
            await self._run_single_client(config, request_iterator, context)
        else:
            # If client_processes > 1, offload the work to other processes.
            sub_workers = await asyncio.gather(
                *[_create_sub_worker() for _ in range(config.client_processes)]
            )

            calls = [worker.stub.RunClient() for worker in sub_workers]

            # Children must not fan out further.
            config_request.setup.client_processes = 1

            for call in calls:
                await call.write(config_request)
                # An empty status indicates the peer is ready
                await call.read()

            start_time = time.monotonic()
            # Aggregate histogram merged from all child workers.
            result = histogram.Histogram(
                config.histogram_params.resolution,
                config.histogram_params.max_possible,
            )
            end_time = time.monotonic()
            await context.write(
                _get_client_status(start_time, end_time, result)
            )

            async for request in request_iterator:
                end_time = time.monotonic()

                # Collect and merge each child's latency report.
                for call in calls:
                    _LOGGER.debug("Fetching status...")
                    await call.write(request)
                    sub_status = await call.read()
                    result.merge(sub_status.stats.latencies)
                    _LOGGER.debug(
                        "Update from sub worker count=[%d]",
                        sub_status.stats.latencies.count,
                    )

                status = _get_client_status(start_time, end_time, result)
                if request.mark.reset:
                    # Clear merged data and restart the stats window.
                    result.reset()
                    start_time = time.monotonic()
                _LOGGER.debug(
                    "Reporting count=[%d]", status.stats.latencies.count
                )
                await context.write(status)

            # Driver half-closed: shut down children in order.
            for call in calls:
                await call.done_writing()

            for worker in sub_workers:
                await worker.stub.QuitWorker(control_pb2.Void())
                await worker.channel.close()
                _LOGGER.info("Waiting for sub worker [%s] to quit...", worker)
                await worker.process.wait()
                _LOGGER.info("Sub worker [%s] quit", worker)

    @staticmethod
    async def CoreCount(unused_request, unused_context):
        """Reports the number of CPU cores on this machine."""
        return control_pb2.CoreResponse(cores=_NUM_CORES)

    async def QuitWorker(self, unused_request, unused_context):
        """Signals this worker to shut down (unblocks wait_for_quit)."""
        _LOGGER.info("QuitWorker command received.")
        self._quit_event.set()
        return control_pb2.Void()

    async def wait_for_quit(self):
        """Blocks until a QuitWorker RPC has been received."""
        await self._quit_event.wait()
405