# Copyright 2020 The gRPC Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """The Python AsyncIO Benchmark Clients.""" import abc import asyncio import time import logging import random import grpc from grpc.experimental import aio from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2, messages_pb2) from tests.qps import histogram from tests.unit import resources class GenericStub(object): def __init__(self, channel: aio.Channel): self.UnaryCall = channel.unary_unary( '/grpc.testing.BenchmarkService/UnaryCall') self.StreamingFromServer = channel.unary_stream( '/grpc.testing.BenchmarkService/StreamingFromServer') self.StreamingCall = channel.stream_stream( '/grpc.testing.BenchmarkService/StreamingCall') class BenchmarkClient(abc.ABC): """Benchmark client interface that exposes a non-blocking send_request().""" def __init__(self, address: str, config: control_pb2.ClientConfig, hist: histogram.Histogram): # Disables underlying reuse of subchannels unique_option = (('iv', random.random()),) # Parses the channel argument from config channel_args = tuple( (arg.name, arg.str_value) if arg.HasField('str_value') else ( arg.name, int(arg.int_value)) for arg in config.channel_args) # Creates the channel if config.HasField('security_params'): channel_credentials = grpc.ssl_channel_credentials( resources.test_root_certificates(),) server_host_override_option = (( 'grpc.ssl_target_name_override', config.security_params.server_host_override, ),) self._channel = aio.secure_channel( address, channel_credentials, unique_option + channel_args + server_host_override_option) else: self._channel = aio.insecure_channel(address, options=unique_option + channel_args) # Creates the stub if config.payload_config.WhichOneof('payload') == 'simple_params': self._generic = False self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub( self._channel) payload = messages_pb2.Payload( body=b'\0' * config.payload_config.simple_params.req_size) self._request = messages_pb2.SimpleRequest( payload=payload, response_size=config.payload_config.simple_params.resp_size) else: self._generic = True self._stub = GenericStub(self._channel) self._request = b'\0' * config.payload_config.bytebuf_params.req_size self._hist = hist self._response_callbacks = [] self._concurrency = config.outstanding_rpcs_per_channel async def run(self) -> None: await self._channel.channel_ready() async def stop(self) -> None: await self._channel.close() def _record_query_time(self, query_time: float) -> None: self._hist.add(query_time * 1e9) class UnaryAsyncBenchmarkClient(BenchmarkClient): def __init__(self, address: str, config: control_pb2.ClientConfig, hist: histogram.Histogram): super().__init__(address, config, hist) self._running = None self._stopped = asyncio.Event() async def _send_request(self): start_time = time.monotonic() await self._stub.UnaryCall(self._request) self._record_query_time(time.monotonic() - start_time) async def _send_indefinitely(self) -> None: while self._running: await self._send_request() async def run(self) -> None: await super().run() self._running = True senders = (self._send_indefinitely() for _ in range(self._concurrency)) await asyncio.gather(*senders) self._stopped.set() async def stop(self) -> None: self._running = False await self._stopped.wait() await super().stop() class StreamingAsyncBenchmarkClient(BenchmarkClient): def __init__(self, address: str, config: control_pb2.ClientConfig, hist: histogram.Histogram): super().__init__(address, config, hist) self._running = None self._stopped = asyncio.Event() async def _one_streaming_call(self): call = self._stub.StreamingCall() while self._running: start_time = time.time() await call.write(self._request) await call.read() self._record_query_time(time.time() - start_time) await call.done_writing() async def run(self): await super().run() self._running = True senders = (self._one_streaming_call() for _ in range(self._concurrency)) await asyncio.gather(*senders) self._stopped.set() async def stop(self): self._running = False await self._stopped.wait() await super().stop() class ServerStreamingAsyncBenchmarkClient(BenchmarkClient): def __init__(self, address: str, config: control_pb2.ClientConfig, hist: histogram.Histogram): super().__init__(address, config, hist) self._running = None self._stopped = asyncio.Event() async def _one_server_streaming_call(self): call = self._stub.StreamingFromServer(self._request) while self._running: start_time = time.time() await call.read() self._record_query_time(time.time() - start_time) async def run(self): await super().run() self._running = True senders = ( self._one_server_streaming_call() for _ in range(self._concurrency)) await asyncio.gather(*senders) self._stopped.set() async def stop(self): self._running = False await self._stopped.wait() await super().stop()