• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""The Python AsyncIO Benchmark Clients."""
15
16import abc
17import asyncio
18import logging
19import random
20import time
21
22import grpc
23from grpc.experimental import aio
24
25from src.proto.grpc.testing import benchmark_service_pb2_grpc
26from src.proto.grpc.testing import control_pb2
27from src.proto.grpc.testing import messages_pb2
28from tests.qps import histogram
29from tests.unit import resources
30
31
class GenericStub:
    """Hand-built BenchmarkService stub created without message serializers.

    Each attribute is a multi-callable obtained from the channel using only
    the fully-qualified method path; no request serializer or response
    deserializer is supplied to the channel.
    """

    def __init__(self, channel: aio.Channel):
        """Create the three BenchmarkService multi-callables on *channel*.

        Args:
            channel: The AsyncIO channel the RPCs will be issued on.
        """
        unary_path = "/grpc.testing.BenchmarkService/UnaryCall"
        server_streaming_path = (
            "/grpc.testing.BenchmarkService/StreamingFromServer"
        )
        bidi_streaming_path = "/grpc.testing.BenchmarkService/StreamingCall"

        self.UnaryCall = channel.unary_unary(unary_path)
        self.StreamingFromServer = channel.unary_stream(server_streaming_path)
        self.StreamingCall = channel.stream_stream(bidi_streaming_path)
43
44
class BenchmarkClient(abc.ABC):
    """Common base of the AsyncIO benchmark clients.

    The constructor builds the channel, the stub and the reusable request
    from the client config. Subclasses drive the actual traffic and report
    every measured latency through ``_record_query_time()``.
    """

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        """Set up channel, stub and request according to *config*.

        Args:
            address: Target handed to the channel constructor.
            config: Client configuration. ``channel_args``,
                ``security_params``, ``payload_config`` and
                ``outstanding_rpcs_per_channel`` are read from it.
            hist: Histogram that receives the recorded latencies.
        """
        # A random, per-client channel argument; meant to keep this channel
        # from reusing the subchannels of another one.
        unique_option = (("iv", random.random()),)

        # Translate the configured channel arguments: a value that is not a
        # string value is forwarded as an integer.
        configured_args = []
        for arg in config.channel_args:
            if arg.HasField("str_value"):
                value = arg.str_value
            else:
                value = int(arg.int_value)
            configured_args.append((arg.name, value))
        options = unique_option + tuple(configured_args)

        # Open the channel; TLS is used only when security_params is present.
        if not config.HasField("security_params"):
            self._channel = aio.insecure_channel(address, options=options)
        else:
            credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates(),
            )
            host_override = (
                "grpc.ssl_target_name_override",
                config.security_params.server_host_override,
            )
            self._channel = aio.secure_channel(
                address,
                credentials,
                options + (host_override,),
            )

        # Pick the stub and pre-build the request that every RPC re-sends.
        # Any payload kind other than "simple_params" takes the generic path.
        payload_config = config.payload_config
        self._generic = (
            payload_config.WhichOneof("payload") != "simple_params"
        )
        if self._generic:
            self._stub = GenericStub(self._channel)
            # Generic requests are plain zero-filled byte strings.
            self._request = b"\0" * payload_config.bytebuf_params.req_size
        else:
            simple_params = payload_config.simple_params
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                self._channel
            )
            self._request = messages_pb2.SimpleRequest(
                payload=messages_pb2.Payload(
                    body=b"\0" * simple_params.req_size
                ),
                response_size=simple_params.resp_size,
            )

        self._hist = hist
        self._response_callbacks = []
        self._concurrency = config.outstanding_rpcs_per_channel

    async def run(self) -> None:
        """Wait until the channel is ready."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Close the channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Add one latency sample to the histogram.

        Args:
            query_time: Elapsed time of one query; it is scaled by 1e9
                before being added (seconds to nanoseconds, given that the
                callers in this file pass differences of clock readings in
                seconds).
        """
        scaled = query_time * 1e9
        self._hist.add(scaled)
118
119
class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that issues back-to-back unary calls.

    ``run()`` starts ``self._concurrency`` senders, each of which keeps one
    unary RPC in flight until ``stop()`` is called.
    """

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        """Initialise the base client and the run/stop bookkeeping.

        Args:
            address: Target handed to the channel constructor.
            config: Client configuration.
            hist: Histogram that receives the recorded latencies.
        """
        super().__init__(address, config, hist)
        # None until run() starts, True while sending, False once stopping.
        self._running = None
        # Set when run() has finished, so stop() knows the senders are done.
        self._stopped = asyncio.Event()

    async def _send_request(self) -> None:
        """Perform one unary call and record how long it took."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Send requests one after another until the client is stopped."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Wait for the channel, then send until ``stop()`` is called.

        Raises:
            Whatever the first failing sender raised; the exception is
            propagated unchanged by ``asyncio.gather``.
        """
        await super().run()
        self._running = True
        senders = (self._send_indefinitely() for _ in range(self._concurrency))
        try:
            await asyncio.gather(*senders)
        finally:
            # Also reached when a sender fails or run() is cancelled:
            # let the surviving senders wind down and release stop(), which
            # would otherwise wait on the event forever.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Ask the senders to finish, wait for run() to end, close the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
151
152
class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that ping-pongs messages over bidirectional streams.

    ``run()`` opens ``self._concurrency`` streaming calls; on each of them a
    request is written and one response is read, repeatedly, until
    ``stop()`` is called.
    """

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        """Initialise the base client and the run/stop bookkeeping.

        Args:
            address: Target handed to the channel constructor.
            config: Client configuration.
            hist: Histogram that receives the recorded latencies.
        """
        super().__init__(address, config, hist)
        # None until run() starts, True while sending, False once stopping.
        self._running = None
        # Set when run() has finished, so stop() knows the streams are done.
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self) -> None:
        """Drive one stream: write, read, record the round trip, repeat."""
        call = self._stub.StreamingCall()
        while self._running:
            # Monotonic clock: an interval must not be affected by wall-clock
            # adjustments (and it matches the unary client's timing).
            start_time = time.monotonic()
            await call.write(self._request)
            await call.read()
            self._record_query_time(time.monotonic() - start_time)
        await call.done_writing()

    async def run(self) -> None:
        """Wait for the channel, then stream until ``stop()`` is called.

        Raises:
            Whatever the first failing stream raised; the exception is
            propagated unchanged by ``asyncio.gather``.
        """
        await super().run()
        self._running = True
        senders = (self._one_streaming_call() for _ in range(self._concurrency))
        try:
            await asyncio.gather(*senders)
        finally:
            # Also reached when a stream fails or run() is cancelled:
            # let the surviving streams wind down and release stop(), which
            # would otherwise wait on the event forever.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Ask the streams to finish, wait for run() to end, close the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
184
185
class ServerStreamingAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that reads from server-streaming calls.

    ``run()`` opens ``self._concurrency`` StreamingFromServer calls and, on
    each of them, records the time every individual read takes until
    ``stop()`` is called.
    """

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        """Initialise the base client and the run/stop bookkeeping.

        Args:
            address: Target handed to the channel constructor.
            config: Client configuration.
            hist: Histogram that receives the recorded latencies.
        """
        super().__init__(address, config, hist)
        # None until run() starts, True while reading, False once stopping.
        self._running = None
        # Set when run() has finished, so stop() knows the readers are done.
        self._stopped = asyncio.Event()

    async def _one_server_streaming_call(self) -> None:
        """Open one server stream and time each read until stopped."""
        call = self._stub.StreamingFromServer(self._request)
        while self._running:
            # Monotonic clock: an interval must not be affected by wall-clock
            # adjustments (and it matches the unary client's timing).
            start_time = time.monotonic()
            await call.read()
            self._record_query_time(time.monotonic() - start_time)

    async def run(self) -> None:
        """Wait for the channel, then read until ``stop()`` is called.

        Raises:
            Whatever the first failing reader raised; the exception is
            propagated unchanged by ``asyncio.gather``.
        """
        await super().run()
        self._running = True
        senders = (
            self._one_server_streaming_call() for _ in range(self._concurrency)
        )
        try:
            await asyncio.gather(*senders)
        finally:
            # Also reached when a reader fails or run() is cancelled:
            # let the surviving readers wind down and release stop(), which
            # would otherwise wait on the event forever.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Ask the readers to finish, wait for run() to end, close the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
217