# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python AsyncIO Benchmark Clients."""

import abc
import asyncio
import time
import logging
import random

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
                                    messages_pb2)
from tests.qps import histogram
from tests.unit import resources


class GenericStub(object):
    """BenchmarkService stub that exchanges raw bytes.

    No request serializer or response deserializer is registered on the
    multi-callables below, so payloads are passed through as ``bytes``.
    """

    def __init__(self, channel: aio.Channel):
        self.UnaryCall = channel.unary_unary(
            '/grpc.testing.BenchmarkService/UnaryCall')
        self.StreamingFromServer = channel.unary_stream(
            '/grpc.testing.BenchmarkService/StreamingFromServer')
        self.StreamingCall = channel.stream_stream(
            '/grpc.testing.BenchmarkService/StreamingCall')


class BenchmarkClient(abc.ABC):
    """Benchmark client interface driven by the async run() and stop().

    Builds one aio channel from a ``control_pb2.ClientConfig`` and records
    per-query latencies into the supplied histogram.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Creates the channel, stub and request payload.

        Args:
          address: Target address of the benchmark server.
          config: Client configuration; reads ``channel_args``,
            ``security_params``, ``payload_config`` and
            ``outstanding_rpcs_per_channel``.
          hist: Histogram that receives latency samples in nanoseconds.
        """
        # Disables underlying reuse of subchannels
        unique_option = (('iv', random.random()),)

        # Parses the channel argument from config: string args keep their
        # string value, every other arg is passed as an int.
        channel_args = tuple(
            (arg.name, arg.str_value) if arg.HasField('str_value') else (
                arg.name, int(arg.int_value)) for arg in config.channel_args)

        # Creates the channel
        if config.HasField('security_params'):
            channel_credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates(),)
            server_host_override_option = ((
                'grpc.ssl_target_name_override',
                config.security_params.server_host_override,
            ),)
            self._channel = aio.secure_channel(
                address, channel_credentials,
                unique_option + channel_args + server_host_override_option)
        else:
            self._channel = aio.insecure_channel(address,
                                                 options=unique_option +
                                                 channel_args)

        # Creates the stub: protobuf messages for 'simple_params', otherwise
        # a raw-bytes request sized by 'bytebuf_params'.
        if config.payload_config.WhichOneof('payload') == 'simple_params':
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                self._channel)
            payload = messages_pb2.Payload(
                body=b'\0' * config.payload_config.simple_params.req_size)
            self._request = messages_pb2.SimpleRequest(
                payload=payload,
                response_size=config.payload_config.simple_params.resp_size)
        else:
            self._generic = True
            self._stub = GenericStub(self._channel)
            self._request = b'\0' * config.payload_config.bytebuf_params.req_size

        self._hist = hist
        self._response_callbacks = []
        # Number of concurrent senders each subclass runs on this channel.
        self._concurrency = config.outstanding_rpcs_per_channel

    async def run(self) -> None:
        """Waits until the channel is ready; subclasses then start RPCs."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Closes the underlying channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Adds a latency sample, given in seconds, to the histogram in ns."""
        self._hist.add(query_time * 1e9)


class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Sends unary RPCs back-to-back from several concurrent coroutines."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None until run() starts the senders; stop() sets it to False.
        self._running = None
        # Set once run() has finished, so stop() can safely close the channel.
        self._stopped = asyncio.Event()

    async def _send_request(self) -> None:
        """Issues one unary RPC and records its latency."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Issues unary RPCs one after another until stop() is requested."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Runs ``self._concurrency`` sender loops until stop() is called."""
        await super().run()
        self._running = True
        senders = (self._send_indefinitely() for _ in range(self._concurrency))
        try:
            await asyncio.gather(*senders)
        finally:
            # Signal stop() even if a sender raised or run() was cancelled;
            # otherwise stop() would wait on this event forever.
            self._stopped.set()

    async def stop(self) -> None:
        """Ends the sender loops, waits for run() to finish, closes channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()


class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Runs concurrent ping-pong streaming RPCs (one write, then one read)."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None until run() starts the calls; stop() sets it to False.
        self._running = None
        # Set once run() has finished, so stop() can safely close the channel.
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self):
        """Drives one bidi stream, timing each write/read round trip."""
        call = self._stub.StreamingCall()
        while self._running:
            # Monotonic clock: wall-clock time.time() can jump and yield
            # negative or skewed latency samples.
            start_time = time.monotonic()
            await call.write(self._request)
            await call.read()
            self._record_query_time(time.monotonic() - start_time)
        await call.done_writing()

    async def run(self):
        """Runs ``self._concurrency`` streams until stop() is called."""
        await super().run()
        self._running = True
        senders = (self._one_streaming_call() for _ in range(self._concurrency))
        try:
            await asyncio.gather(*senders)
        finally:
            # Signal stop() even if a stream failed or run() was cancelled;
            # otherwise stop() would wait on this event forever.
            self._stopped.set()

    async def stop(self):
        """Ends the streams, waits for run() to finish, closes the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()


class ServerStreamingAsyncBenchmarkClient(BenchmarkClient):
    """Runs concurrent server-streaming RPCs, timing each received message."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None until run() starts the calls; stop() sets it to False.
        self._running = None
        # Set once run() has finished, so stop() can safely close the channel.
        self._stopped = asyncio.Event()

    async def _one_server_streaming_call(self):
        """Reads from one server stream, recording the wait for each read."""
        call = self._stub.StreamingFromServer(self._request)
        try:
            while self._running:
                start_time = time.monotonic()
                await call.read()
                self._record_query_time(time.monotonic() - start_time)
        finally:
            # The client never drains this RPC to completion, so cancel it
            # explicitly instead of leaving it open until the channel closes.
            call.cancel()

    async def run(self):
        """Runs ``self._concurrency`` server streams until stop() is called."""
        await super().run()
        self._running = True
        senders = (
            self._one_server_streaming_call() for _ in range(self._concurrency))
        try:
            await asyncio.gather(*senders)
        finally:
            # Signal stop() even if a stream failed or run() was cancelled;
            # otherwise stop() would wait on this event forever.
            self._stopped.set()

    async def stop(self):
        """Ends the streams, waits for run() to finish, closes the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()