# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python AsyncIO Benchmark Clients."""

import abc
import asyncio
import logging
import random
import time

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import control_pb2
from src.proto.grpc.testing import messages_pb2
from tests.qps import histogram
from tests.unit import resources


class GenericStub(object):
    """Stub for BenchmarkService built directly from method paths.

    Used when the payload config is not ``simple_params`` (see
    BenchmarkClient.__init__). No request serializer or response deserializer
    is passed to the multicallables, so requests are given as raw bytes.
    """

    def __init__(self, channel: aio.Channel):
        self.UnaryCall = channel.unary_unary(
            "/grpc.testing.BenchmarkService/UnaryCall"
        )
        self.StreamingFromServer = channel.unary_stream(
            "/grpc.testing.BenchmarkService/StreamingFromServer"
        )
        self.StreamingCall = channel.stream_stream(
            "/grpc.testing.BenchmarkService/StreamingCall"
        )


class BenchmarkClient(abc.ABC):
    """Benchmark client interface that exposes a non-blocking send_request()."""

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        """Create the channel, stub and pre-built request for a benchmark run.

        Args:
          address: Target address of the benchmark server.
          config: Client configuration. Its channel_args, security_params,
            payload_config and outstanding_rpcs_per_channel fields are read.
          hist: Histogram that receives per-RPC latencies in nanoseconds.
        """
        # Disables underlying reuse of subchannels
        unique_option = (("iv", random.random()),)

        # Parses the channel argument from config
        channel_args = tuple(
            (arg.name, arg.str_value)
            if arg.HasField("str_value")
            else (arg.name, int(arg.int_value))
            for arg in config.channel_args
        )

        # Creates the channel
        if config.HasField("security_params"):
            channel_credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates(),
            )
            server_host_override_option = (
                (
                    "grpc.ssl_target_name_override",
                    config.security_params.server_host_override,
                ),
            )
            self._channel = aio.secure_channel(
                address,
                channel_credentials,
                unique_option + channel_args + server_host_override_option,
            )
        else:
            self._channel = aio.insecure_channel(
                address, options=unique_option + channel_args
            )

        # Creates the stub
        if config.payload_config.WhichOneof("payload") == "simple_params":
            # Protobuf path: typed stub with a SimpleRequest message.
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                self._channel
            )
            payload = messages_pb2.Payload(
                body=b"\0" * config.payload_config.simple_params.req_size
            )
            self._request = messages_pb2.SimpleRequest(
                payload=payload,
                response_size=config.payload_config.simple_params.resp_size,
            )
        else:
            # Generic path: raw bytes through GenericStub.
            self._generic = True
            self._stub = GenericStub(self._channel)
            self._request = (
                b"\0" * config.payload_config.bytebuf_params.req_size
            )

        self._hist = hist
        # NOTE(review): not read anywhere in this file; kept for compatibility.
        self._response_callbacks = []
        # Number of concurrent senders each subclass spawns in run().
        self._concurrency = config.outstanding_rpcs_per_channel

    async def run(self) -> None:
        """Wait until the channel is ready."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Close the channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Add one latency sample, given in seconds, to the histogram as ns."""
        self._hist.add(query_time * 1e9)


class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Issues back-to-back unary calls from `_concurrency` concurrent senders."""

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        super().__init__(address, config, hist)
        # Tri-state: None = run() not started, True = running,
        # False = stop requested.
        self._running = None
        # Set by run() once every sender has finished (or run() failed).
        self._stopped = asyncio.Event()

    async def _send_request(self):
        """Perform one unary call and record its latency."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Keep sending unary calls until stop() clears the running flag."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Run the senders until stop() is called.

        `_stopped` is always set on exit, even on failure, so that a
        concurrent stop() cannot block forever.
        """
        # Only transition from "not started"; a stop() that arrived earlier
        # (e.g. while channel_ready() was pending) must not be overwritten.
        if self._running is None:
            self._running = True
        try:
            await super().run()
            senders = (
                self._send_indefinitely() for _ in range(self._concurrency)
            )
            await asyncio.gather(*senders)
        finally:
            self._stopped.set()

    async def stop(self) -> None:
        """Ask the senders to exit, wait for run() to finish, close channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()


class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Ping-pongs requests over `_concurrency` bidirectional streaming calls."""

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        super().__init__(address, config, hist)
        # Tri-state: None = run() not started, True = running,
        # False = stop requested.
        self._running = None
        # Set by run() once every stream has finished (or run() failed).
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self):
        """Write one request and read one response per iteration, timing each.

        When the running flag is cleared, the client half-closes the stream.
        """
        call = self._stub.StreamingCall()
        while self._running:
            # Monotonic clock: immune to wall-clock adjustments, and
            # consistent with UnaryAsyncBenchmarkClient.
            start_time = time.monotonic()
            await call.write(self._request)
            await call.read()
            self._record_query_time(time.monotonic() - start_time)
        await call.done_writing()

    async def run(self):
        """Run the streams until stop() is called.

        `_stopped` is always set on exit, even on failure, so that a
        concurrent stop() cannot block forever.
        """
        # Only transition from "not started"; honor an earlier stop().
        if self._running is None:
            self._running = True
        try:
            await super().run()
            senders = (
                self._one_streaming_call() for _ in range(self._concurrency)
            )
            await asyncio.gather(*senders)
        finally:
            self._stopped.set()

    async def stop(self):
        """Ask the streams to exit, wait for run() to finish, close channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()


class ServerStreamingAsyncBenchmarkClient(BenchmarkClient):
    """Reads responses from `_concurrency` server-streaming calls."""

    def __init__(
        self,
        address: str,
        config: control_pb2.ClientConfig,
        hist: histogram.Histogram,
    ):
        super().__init__(address, config, hist)
        # Tri-state: None = run() not started, True = running,
        # False = stop requested.
        self._running = None
        # Set by run() once every reader has finished (or run() failed).
        self._stopped = asyncio.Event()

    async def _one_server_streaming_call(self):
        """Start one server-streaming call and time each read until stopped."""
        call = self._stub.StreamingFromServer(self._request)
        while self._running:
            # Monotonic clock: immune to wall-clock adjustments, and
            # consistent with UnaryAsyncBenchmarkClient.
            start_time = time.monotonic()
            await call.read()
            self._record_query_time(time.monotonic() - start_time)

    async def run(self):
        """Run the readers until stop() is called.

        `_stopped` is always set on exit, even on failure, so that a
        concurrent stop() cannot block forever.
        """
        # Only transition from "not started"; honor an earlier stop().
        if self._running is None:
            self._running = True
        try:
            await super().run()
            senders = (
                self._one_server_streaming_call()
                for _ in range(self._concurrency)
            )
            await asyncio.gather(*senders)
        finally:
            self._stopped.set()

    async def stop(self):
        """Ask the readers to exit, wait for run() to finish, close channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()