# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""

import abc
import threading
import time

from concurrent import futures
from six.moves import queue

import grpc
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import benchmark_service_pb2_grpc
from tests.unit import resources
from tests.unit import test_common

# Effectively "no deadline" (24h): benchmark RPCs are ended by stop(), not by
# hitting this timeout.
_TIMEOUT = 60 * 60 * 24


class GenericStub(object):
    """Stub for the benchmark service used with raw-bytes payloads.

    No serializers are registered on the multicallables, so requests and
    responses pass through as uninterpreted byte strings.
    """

    def __init__(self, channel):
        self.UnaryCall = channel.unary_unary(
            '/grpc.testing.BenchmarkService/UnaryCall')
        self.StreamingFromServer = channel.unary_stream(
            '/grpc.testing.BenchmarkService/StreamingFromServer')
        self.StreamingCall = channel.stream_stream(
            '/grpc.testing.BenchmarkService/StreamingCall')


class BenchmarkClient:
    """Benchmark client interface that exposes a non-blocking send_request()."""

    __metaclass__ = abc.ABCMeta

    def __init__(self, server, config, hist):
        """Builds the channel, stub, and canned request from the test config.

        Args:
          server: "host:port" target string for the benchmark server.
          config: a ClientConfig proto (see src/proto/grpc/testing/control.proto).
          hist: histogram that accumulates per-RPC latencies (in nanoseconds).
        """
        # Create the stub: secure or insecure depending on the config.
        if config.HasField('security_params'):
            creds = grpc.ssl_channel_credentials(
                resources.test_root_certificates())
            channel = test_common.test_secure_channel(
                server, creds, config.security_params.server_host_override)
        else:
            channel = grpc.insecure_channel(server)

        # Waits for the channel to be ready before we start sending messages,
        # so connection setup time is excluded from the measurements.
        grpc.channel_ready_future(channel).result()

        if config.payload_config.WhichOneof('payload') == 'simple_params':
            # Protobuf payload: use the generated stub and SimpleRequest.
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                channel)
            payload = messages_pb2.Payload(
                body=b'\0' * config.payload_config.simple_params.req_size)
            self._request = messages_pb2.SimpleRequest(
                payload=payload,
                response_size=config.payload_config.simple_params.resp_size)
        else:
            # Raw-bytes payload: use the serializer-free generic stub.
            self._generic = True
            self._stub = GenericStub(channel)
            self._request = b'\0' * config.payload_config.bytebuf_params.req_size

        self._hist = hist
        self._response_callbacks = []

    def add_response_callback(self, callback):
        """callback will be invoked as callback(client, query_time)"""
        self._response_callbacks.append(callback)

    @abc.abstractmethod
    def send_request(self):
        """Non-blocking wrapper for a client's request operation."""
        raise NotImplementedError()

    def start(self):
        pass

    def stop(self):
        pass

    def _handle_response(self, client, query_time):
        """Records one completed RPC and notifies registered callbacks.

        Args:
          client: the BenchmarkClient the response belongs to.
          query_time: RPC round-trip time in seconds.
        """
        self._hist.add(query_time * 1e9)  # Report times in nanoseconds
        for callback in self._response_callbacks:
            callback(client, query_time)


class UnarySyncBenchmarkClient(BenchmarkClient):
    """Unary-unary client that issues blocking calls from a thread pool."""

    def __init__(self, server, config, hist):
        super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=config.outstanding_rpcs_per_channel)

    def send_request(self):
        # Send requests in separate threads to support multiple outstanding rpcs
        # (See src/proto/grpc/testing/control.proto)
        self._pool.submit(self._dispatch_request)

    def stop(self):
        # wait=True drains in-flight RPCs before the stub is released.
        self._pool.shutdown(wait=True)
        self._stub = None

    def _dispatch_request(self):
        """Runs one blocking UnaryCall and records its round-trip time."""
        start_time = time.time()
        self._stub.UnaryCall(self._request, _TIMEOUT)
        end_time = time.time()
        self._handle_response(self, end_time - start_time)


class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Unary-unary client that uses Future callbacks instead of threads."""

    def send_request(self):
        # Use the Future callback api to support multiple outstanding rpcs
        start_time = time.time()
        response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
        response_future.add_done_callback(
            lambda resp: self._response_received(start_time, resp))

    def _response_received(self, start_time, resp):
        # result() re-raises if the RPC terminated with an error.
        resp.result()
        end_time = time.time()
        self._handle_response(self, end_time - start_time)

    def stop(self):
        self._stub = None


class _SyncStream(object):
    """One bidirectional StreamingCall fed by a request queue."""

    def __init__(self, stub, generic, request, handle_response):
        self._stub = stub
        self._generic = generic
        self._request = request
        self._handle_response = handle_response
        self._is_streaming = False
        self._request_queue = queue.Queue()
        # Enqueue times parallel the request queue; popped as responses arrive
        # to compute per-message round-trip latency.
        self._send_time_queue = queue.Queue()

    def send_request(self):
        """Queues one request (and its send timestamp) for the stream."""
        self._send_time_queue.put(time.time())
        self._request_queue.put(self._request)

    def start(self):
        """Blocks consuming responses until the RPC terminates."""
        self._is_streaming = True
        response_stream = self._stub.StreamingCall(self._request_generator(),
                                                   _TIMEOUT)
        for _ in response_stream:
            self._handle_response(
                self,
                time.time() - self._send_time_queue.get_nowait())

    def stop(self):
        # Flips the flag so _request_generator exits within its 1s poll;
        # the response loop then ends when the RPC completes.
        self._is_streaming = False

    def _request_generator(self):
        """Yields queued requests until stop() clears the streaming flag."""
        while self._is_streaming:
            try:
                # Short timeout so the generator re-checks _is_streaming and
                # can terminate promptly after stop().
                request = self._request_queue.get(block=True, timeout=1.0)
                yield request
            except queue.Empty:
                pass


class StreamingSyncBenchmarkClient(BenchmarkClient):
    """Bidi-streaming client multiplexing over several blocking streams."""

    def __init__(self, server, config, hist):
        super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=config.outstanding_rpcs_per_channel)
        self._streams = [
            _SyncStream(self._stub, self._generic, self._request,
                        self._handle_response)
            for _ in range(config.outstanding_rpcs_per_channel)
        ]
        self._curr_stream = 0

    def send_request(self):
        # Use a round_robin scheduler to determine what stream to send on
        self._streams[self._curr_stream].send_request()
        self._curr_stream = (self._curr_stream + 1) % len(self._streams)

    def start(self):
        # Each stream's blocking response loop runs on its own pool thread.
        for stream in self._streams:
            self._pool.submit(stream.start)

    def stop(self):
        for stream in self._streams:
            stream.stop()
        self._pool.shutdown(wait=True)
        self._stub = None


class ServerStreamingSyncBenchmarkClient(BenchmarkClient):
    """Server-streaming client; each send_request() opens one stream."""

    def __init__(self, server, config, hist):
        super(ServerStreamingSyncBenchmarkClient,
              self).__init__(server, config, hist)
        # With a single outstanding RPC, a dedicated thread is used instead of
        # a one-worker pool.
        if config.outstanding_rpcs_per_channel == 1:
            self._pool = None
        else:
            self._pool = futures.ThreadPoolExecutor(
                max_workers=config.outstanding_rpcs_per_channel)
        self._rpcs = []       # live response streams, cancelled in stop()
        self._sender = None   # the dedicated thread, when no pool is used

    def send_request(self):
        if self._pool is None:
            self._sender = threading.Thread(
                target=self._one_stream_streaming_rpc, daemon=True)
            self._sender.start()
        else:
            self._pool.submit(self._one_stream_streaming_rpc)

    def _one_stream_streaming_rpc(self):
        """Opens one StreamingFromServer call and records inter-message gaps."""
        response_stream = self._stub.StreamingFromServer(
            self._request, _TIMEOUT)
        self._rpcs.append(response_stream)
        start_time = time.time()
        for _ in response_stream:
            # Latency here is the time between consecutive server messages.
            self._handle_response(self, time.time() - start_time)
            start_time = time.time()

    def stop(self):
        # Cancel the streams first so the consuming threads can unblock.
        for call in self._rpcs:
            call.cancel()
        if self._sender is not None:
            self._sender.join()
        if self._pool is not None:
            self._pool.shutdown(wait=False)
        self._stub = None