• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python AsyncIO Benchmark Clients."""
15
16import abc
17import asyncio
18import time
19import logging
20import random
21
22import grpc
23from grpc.experimental import aio
24
25from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
26                                    messages_pb2)
27from tests.qps import histogram
28from tests.unit import resources
29
30
class GenericStub:
    """Hand-written stub for three BenchmarkService methods.

    Each multi-callable is created from the method path alone; no request
    serializer or response deserializer is passed to the channel.
    """

    def __init__(self, channel: aio.Channel):
        """Creates one multi-callable per benchmark method.

        Args:
          channel: The channel the multi-callables are created on.
        """
        unary_path = '/grpc.testing.BenchmarkService/UnaryCall'
        server_streaming_path = (
            '/grpc.testing.BenchmarkService/StreamingFromServer')
        bidi_streaming_path = '/grpc.testing.BenchmarkService/StreamingCall'

        self.UnaryCall = channel.unary_unary(unary_path)
        self.StreamingFromServer = channel.unary_stream(server_streaming_path)
        self.StreamingCall = channel.stream_stream(bidi_streaming_path)
40
41
class BenchmarkClient(abc.ABC):
    """Common base of the AsyncIO benchmark clients.

    The constructor builds the channel, the stub and the request object
    described by a ``ClientConfig``. ``run()`` and ``stop()`` only wait for
    the channel to become ready and close it, respectively.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Creates the channel, the stub and the request to be sent.

        Args:
          address: Target passed to the channel factory.
          config: Client configuration; its channel args, optional security
            params, payload config and outstanding RPC count are read here.
          hist: Histogram that receives the recorded query times.
        """
        # A random option value, meant to disable the underlying reuse of
        # subchannels.
        options = [('iv', random.random())]

        # Channel arguments from the config: string values are taken as-is,
        # everything else is passed as an int.
        for arg in config.channel_args:
            if arg.HasField('str_value'):
                options.append((arg.name, arg.str_value))
            else:
                options.append((arg.name, int(arg.int_value)))

        # Secure channel when security params are present, insecure otherwise.
        if config.HasField('security_params'):
            credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates())
            options.append((
                'grpc.ssl_target_name_override',
                config.security_params.server_host_override,
            ))
            self._channel = aio.secure_channel(address, credentials,
                                               tuple(options))
        else:
            self._channel = aio.insecure_channel(address,
                                                 options=tuple(options))

        # Generated stub with a SimpleRequest for 'simple_params'; otherwise
        # the generic stub with a plain bytes request.
        payload_config = config.payload_config
        self._generic = payload_config.WhichOneof('payload') != 'simple_params'
        if self._generic:
            self._stub = GenericStub(self._channel)
            self._request = b'\0' * payload_config.bytebuf_params.req_size
        else:
            simple_params = payload_config.simple_params
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                self._channel)
            request_payload = messages_pb2.Payload(
                body=b'\0' * simple_params.req_size)
            self._request = messages_pb2.SimpleRequest(
                payload=request_payload,
                response_size=simple_params.resp_size)

        self._hist = hist
        self._response_callbacks = []
        self._concurrency = config.outstanding_rpcs_per_channel

    async def run(self) -> None:
        """Waits until the channel is ready."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Closes the channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Adds one measurement to the histogram, scaled by 1e9.

        Args:
          query_time: Elapsed time of one query; presumably in seconds, so
            that the histogram receives nanoseconds.
        """
        scaled_time = query_time * 1e9
        self._hist.add(scaled_time)
98
99
class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that sends unary RPCs back to back.

    ``run()`` starts ``self._concurrency`` sender coroutines; each one keeps
    issuing ``UnaryCall`` and recording its latency until ``stop()`` is called.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Initializes the base client and the run/stop bookkeeping."""
        super().__init__(address, config, hist)
        # None until run() starts the senders, True while they should keep
        # sending, False once they should wind down.
        self._running = None
        # Set when run() has finished; stop() waits on it before closing.
        self._stopped = asyncio.Event()

    async def _send_request(self) -> None:
        """Sends one unary request and records how long it took."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Sends requests one after another while the client is running."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Waits for the channel, then runs the senders until stopped.

        Raises:
          Exception: Whatever the channel or one of the senders raised; it is
            propagated after the stop event has been set.
        """
        try:
            await super().run()
            self._running = True
            senders = (
                self._send_indefinitely() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always release stop(): without this, a failed RPC or a
            # cancelled run() would leave stop() waiting forever. Clearing
            # the flag also lets any sender that is still alive wind down.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Asks the senders to finish, waits for run() to end, then closes."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
128
129
class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that ping-pongs messages over ``StreamingCall``.

    ``run()`` opens ``self._concurrency`` streams; on each one a request is
    written, one response is read, and the round-trip time is recorded, until
    ``stop()`` is called.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Initializes the base client and the run/stop bookkeeping."""
        super().__init__(address, config, hist)
        # None until run() starts the streams, True while they should keep
        # going, False once they should wind down.
        self._running = None
        # Set when run() has finished; stop() waits on it before closing.
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self) -> None:
        """Drives one stream: write, read, record, until asked to stop."""
        call = self._stub.StreamingCall()
        while self._running:
            # Monotonic clock, as in the unary client: the measured interval
            # must not be affected by system clock updates.
            start_time = time.monotonic()
            await call.write(self._request)
            await call.read()
            self._record_query_time(time.monotonic() - start_time)
        await call.done_writing()

    async def run(self) -> None:
        """Waits for the channel, then runs the streams until stopped.

        Raises:
          Exception: Whatever the channel or one of the streams raised; it is
            propagated after the stop event has been set.
        """
        try:
            await super().run()
            self._running = True
            senders = (
                self._one_streaming_call() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always release stop(): without this, a failed RPC or a
            # cancelled run() would leave stop() waiting forever. Clearing
            # the flag also lets any stream that is still alive wind down.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Asks the streams to finish, waits for run() to end, then closes."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
158
159
class ServerStreamingAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that reads from ``StreamingFromServer`` calls.

    ``run()`` opens ``self._concurrency`` calls; on each one the time taken
    by every ``read()`` is recorded until ``stop()`` is called.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Initializes the base client and the run/stop bookkeeping."""
        super().__init__(address, config, hist)
        # None until run() starts the calls, True while they should keep
        # reading, False once they should wind down.
        self._running = None
        # Set when run() has finished; stop() waits on it before closing.
        self._stopped = asyncio.Event()

    async def _one_server_streaming_call(self) -> None:
        """Drives one call: read a message and record how long the read took.

        The value returned by ``read()`` is not inspected.
        """
        call = self._stub.StreamingFromServer(self._request)
        while self._running:
            # Monotonic clock, as in the unary client: the measured interval
            # must not be affected by system clock updates.
            start_time = time.monotonic()
            await call.read()
            self._record_query_time(time.monotonic() - start_time)

    async def run(self) -> None:
        """Waits for the channel, then runs the calls until stopped.

        Raises:
          Exception: Whatever the channel or one of the calls raised; it is
            propagated after the stop event has been set.
        """
        try:
            await super().run()
            self._running = True
            senders = (self._one_server_streaming_call()
                       for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always release stop(): without this, a failed RPC or a
            # cancelled run() would leave stop() waiting forever. Clearing
            # the flag also lets any call that is still alive wind down.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Asks the calls to finish, waits for run() to end, then closes."""
        self._running = False
        await self._stopped.wait()
        await super().stop()
187