• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""The Python AsyncIO Benchmark Clients."""
15
16import abc
17import asyncio
18import time
19import logging
20import random
21
22import grpc
23from grpc.experimental import aio
24
25from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
26                                    messages_pb2)
27from tests.qps import histogram
28from tests.unit import resources
29
30
class GenericStub:
    """Stub for BenchmarkService that skips protobuf (de)serialization.

    The multicallables are created with only a method path and no
    serializer/deserializer, so messages go through the channel untouched;
    the benchmark client pairs this stub with raw ``bytes`` requests.
    """

    def __init__(self, channel: aio.Channel):
        unary_method = '/grpc.testing.BenchmarkService/UnaryCall'
        streaming_method = '/grpc.testing.BenchmarkService/StreamingCall'
        self.UnaryCall = channel.unary_unary(unary_method)
        self.StreamingCall = channel.stream_stream(streaming_method)
38
39
class BenchmarkClient(abc.ABC):
    """Common base for the AsyncIO benchmark clients.

    Builds an ``aio`` channel and a stub from a ``control_pb2.ClientConfig``
    and records per-query latencies into a histogram. Subclasses generate the
    load by overriding ``run()``/``stop()`` and delegating to these versions.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Create the channel, stub and request payload.

        Args:
            address: Target address of the benchmark server.
            config: Client configuration; ``channel_args``,
                ``security_params``, ``payload_config`` and
                ``outstanding_rpcs_per_channel`` are consulted.
            hist: Histogram that receives latency samples (nanoseconds).
        """
        self._channel = self._build_channel(address, config)
        self._generic, self._stub, self._request = self._build_stub(config)
        self._hist = hist
        # NOTE(review): not referenced anywhere in this file.
        self._response_callbacks = []
        self._concurrency = config.outstanding_rpcs_per_channel

    def _build_channel(self, address: str,
                       config: control_pb2.ClientConfig) -> aio.Channel:
        """Create a secure or insecure channel as described by ``config``."""
        # A random, otherwise meaningless option makes every channel unique,
        # which disables the underlying reuse of subchannels.
        options = (('iv', random.random()),)
        # Channel arguments from the config: string-valued ones are passed
        # through, everything else is treated as an integer.
        options += tuple(
            (arg.name, arg.str_value) if arg.HasField('str_value') else
            (arg.name, int(arg.int_value)) for arg in config.channel_args)
        if not config.HasField('security_params'):
            return aio.insecure_channel(address, options=options)
        credentials = grpc.ssl_channel_credentials(
            resources.test_root_certificates())
        options += (('grpc.ssl_target_name_override',
                     config.security_params.server_host_override),)
        return aio.secure_channel(address, credentials, options)

    def _build_stub(self, config: control_pb2.ClientConfig):
        """Return ``(generic, stub, request)`` for the configured payload."""
        payload_config = config.payload_config
        if payload_config.WhichOneof('payload') != 'simple_params':
            # Every non-simple payload goes through the raw-bytes stub.
            stub = GenericStub(self._channel)
            request = b'\0' * payload_config.bytebuf_params.req_size
            return True, stub, request
        params = payload_config.simple_params
        stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(self._channel)
        request = messages_pb2.SimpleRequest(
            payload=messages_pb2.Payload(body=b'\0' * params.req_size),
            response_size=params.resp_size)
        return False, stub, request

    async def run(self) -> None:
        """Wait until the channel is ready to carry RPCs."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Close the underlying channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Add one latency sample, given in seconds, to the histogram in ns."""
        self._hist.add(query_time * 1e9)
96
97
class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Closed-loop unary benchmark client.

    Runs ``outstanding_rpcs_per_channel`` concurrent senders; each one awaits
    a ``UnaryCall`` and immediately issues the next until ``stop()`` is
    requested, recording every call's latency.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None: run() not started yet; True: sending; False: stop requested.
        self._running = None
        # Set when run() finishes (normally or by exception) so stop() can
        # wait for in-flight RPCs before closing the channel.
        self._stopped = asyncio.Event()

    async def _send_request(self) -> None:
        """Issue one unary RPC and record its latency."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Send RPCs back-to-back until a stop is requested."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Wait for the channel, then drive load until ``stop()`` is called."""
        # Mark the client as running before awaiting channel readiness, and
        # only if stop() has not been requested already; otherwise a stop()
        # arriving first would be overwritten and stop() would hang forever.
        if self._running is None:
            self._running = True
        try:
            await super().run()
            senders = (
                self._send_indefinitely() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always signal completion: if an RPC raises, stop() must not be
            # left waiting on an event that would never be set.
            self._stopped.set()

    async def stop(self) -> None:
        """Request the senders to finish, wait for run(), close the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
126
127
class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Ping-pong streaming benchmark client.

    Opens ``outstanding_rpcs_per_channel`` concurrent ``StreamingCall``
    streams; on each one it writes a request and reads a response
    back-to-back until ``stop()`` is requested, recording every round trip.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None: run() not started yet; True: streaming; False: stop requested.
        self._running = None
        # Set when run() finishes (normally or by exception) so stop() can
        # wait for the streams to wind down before closing the channel.
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self) -> None:
        """Ping-pong on a single stream until a stop is requested."""
        call = self._stub.StreamingCall()
        while self._running:
            # Monotonic clock: wall-clock time.time() can be adjusted while
            # the benchmark runs and would distort the latency samples.
            start_time = time.monotonic()
            await call.write(self._request)
            await call.read()
            self._record_query_time(time.monotonic() - start_time)
        # Half-close the stream so the server can finish the call.
        await call.done_writing()

    async def run(self) -> None:
        """Wait for the channel, then drive load until ``stop()`` is called."""
        # Mark the client as running before awaiting channel readiness, and
        # only if stop() has not been requested already; otherwise a stop()
        # arriving first would be overwritten and stop() would hang forever.
        if self._running is None:
            self._running = True
        try:
            await super().run()
            senders = (
                self._one_streaming_call() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always signal completion: if a stream raises, stop() must not
            # be left waiting on an event that would never be set.
            self._stopped.set()

    async def stop(self) -> None:
        """Request the streams to finish, wait for run(), close the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
156