# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python AsyncIO Benchmark Clients."""

import abc
import asyncio
import time
import logging
import random

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
                                    messages_pb2)
from tests.qps import histogram
from tests.unit import resources


class GenericStub(object):
    """BenchmarkService stub whose calls carry pre-encoded payloads.

    No request serializer or response deserializer is passed to the channel
    factories below. Callers therefore supply the wire payload themselves;
    BenchmarkClient passes a ``bytes`` request when the config uses
    ``bytebuf_params``.
    """

    def __init__(self, channel: aio.Channel):
        self.UnaryCall = channel.unary_unary(
            '/grpc.testing.BenchmarkService/UnaryCall')
        self.StreamingCall = channel.stream_stream(
            '/grpc.testing.BenchmarkService/StreamingCall')


class BenchmarkClient(abc.ABC):
    """Base class for the AsyncIO benchmark clients.

    Owns the channel, the stub and the request sent on every RPC. Subclasses
    drive the load from ``run()`` and shut it down from ``stop()``.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Creates the channel, stub and request message for a benchmark.

        Args:
          address: Target address of the benchmark server.
          config: Client configuration. Its ``channel_args``,
            ``security_params``, ``payload_config`` and
            ``outstanding_rpcs_per_channel`` fields are read here.
          hist: Histogram that receives per-RPC latencies in nanoseconds.
        """
        # Disables underlying reuse of subchannels: a random channel argument
        # makes every client's channel arguments distinct.
        unique_option = (('iv', random.random()),)

        # Parses the channel argument from config
        channel_args = tuple(
            (arg.name, arg.str_value) if arg.HasField('str_value') else (
                arg.name, int(arg.int_value)) for arg in config.channel_args)

        # Creates the channel
        if config.HasField('security_params'):
            channel_credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates(),)
            server_host_override_option = ((
                'grpc.ssl_target_name_override',
                config.security_params.server_host_override,
            ),)
            self._channel = aio.secure_channel(
                address, channel_credentials,
                unique_option + channel_args + server_host_override_option)
        else:
            self._channel = aio.insecure_channel(address,
                                                 options=unique_option +
                                                 channel_args)

        # Creates the stub: protobuf messages for simple_params, raw bytes
        # through GenericStub otherwise.
        if config.payload_config.WhichOneof('payload') == 'simple_params':
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                self._channel)
            payload = messages_pb2.Payload(
                body=b'\0' * config.payload_config.simple_params.req_size)
            self._request = messages_pb2.SimpleRequest(
                payload=payload,
                response_size=config.payload_config.simple_params.resp_size)
        else:
            self._generic = True
            self._stub = GenericStub(self._channel)
            self._request = b'\0' * config.payload_config.bytebuf_params.req_size

        self._hist = hist
        # NOTE(review): not read anywhere in this module; kept in case other
        # code relies on the attribute.
        self._response_callbacks = []
        # Number of concurrent senders each subclass starts in run().
        self._concurrency = config.outstanding_rpcs_per_channel

    async def run(self) -> None:
        """Waits until the channel is ready to carry RPCs."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Closes the underlying channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Adds a latency measured in seconds to the histogram, in nanoseconds."""
        self._hist.add(query_time * 1e9)


class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Keeps ``outstanding_rpcs_per_channel`` unary RPCs in flight until stopped."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # Senders loop while this flag is true; stop() clears it.
        self._running = False
        # Set by run() once every sender has exited (normally or not).
        self._stopped = asyncio.Event()

    async def _send_request(self):
        """Issues one unary RPC and records its latency."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Issues unary RPCs back to back until the client is stopped."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Runs the senders until stop() is called or an RPC fails.

        Exceptions raised by a sender propagate to the caller of run().
        """
        await super().run()
        self._running = True
        senders = (self._send_indefinitely() for _ in range(self._concurrency))
        try:
            await asyncio.gather(*senders)
        finally:
            # Always signal completion, even when an RPC raised or run() was
            # cancelled; otherwise a concurrent stop() would wait forever.
            # Clearing the flag also makes surviving senders wind down.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Stops the senders, waits for run() to finish, then closes the channel.

        Note: _stopped is only set by run(), so this waits until run() has
        been started and has finished.
        """
        self._running = False
        await self._stopped.wait()
        await super().stop()


class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Runs ``outstanding_rpcs_per_channel`` ping-pong streams until stopped."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # Streams keep exchanging messages while this flag is true.
        self._running = False
        # Set by run() once every stream loop has exited (normally or not).
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self):
        """Opens one stream and times each write/read round trip on it."""
        call = self._stub.StreamingCall()
        while self._running:
            # Monotonic clock, consistent with the unary client: wall-clock
            # time can jump and produce bogus (even negative) latencies.
            start_time = time.monotonic()
            await call.write(self._request)
            await call.read()
            self._record_query_time(time.monotonic() - start_time)
        await call.done_writing()

    async def run(self):
        """Runs the streams until stop() is called or an RPC fails.

        Exceptions raised by a stream propagate to the caller of run().
        """
        await super().run()
        self._running = True
        senders = (self._one_streaming_call() for _ in range(self._concurrency))
        try:
            await asyncio.gather(*senders)
        finally:
            # Always signal completion so a concurrent stop() cannot hang,
            # and let any surviving streams finish their loop.
            self._running = False
            self._stopped.set()

    async def stop(self):
        """Stops the streams, waits for run() to finish, then closes the channel.

        Note: _stopped is only set by run(), so this waits until run() has
        been started and has finished.
        """
        self._running = False
        await self._stopped.wait()
        await super().stop()