# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""

import abc
import threading
import time

from concurrent import futures
from six.moves import queue

import grpc
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import benchmark_service_pb2_grpc
from tests.unit import resources
from tests.unit import test_common

# RPC timeout of 24 hours, i.e. effectively no deadline during benchmarking.
_TIMEOUT = 60 * 60 * 24


class GenericStub(object):
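    """Generic stub that sends raw bytes over the BenchmarkService methods.

    No request/response serializers are registered, so payloads are passed
    through as opaque byte strings.
    """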

    def __init__(self, channel):
        self.UnaryCall = channel.unary_unary(
            '/grpc.testing.BenchmarkService/UnaryCall')
        self.StreamingCall = channel.stream_stream(
            '/grpc.testing.BenchmarkService/StreamingCall')


class BenchmarkClient:
    """Benchmark client interface that exposes a non-blocking send_request()."""

    __metaclass__ = abc.ABCMeta

    def __init__(self, server, config, hist):
        # Create the channel (secure or insecure, per the config)
        if config.HasField('security_params'):
            creds = grpc.ssl_channel_credentials(
                resources.test_root_certificates())
            channel = test_common.test_secure_channel(
                server, creds, config.security_params.server_host_override)
        else:
            channel = grpc.insecure_channel(server)

        # Wait for the channel to be ready before we start sending messages
        grpc.channel_ready_future(channel).result()

        if config.payload_config.WhichOneof('payload') == 'simple_params':
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                channel)
            payload = messages_pb2.Payload(
                body=b'\0' * config.payload_config.simple_params.req_size)
            self._request = messages_pb2.SimpleRequest(
                payload=payload,
                response_size=config.payload_config.simple_params.resp_size)
        else:
            self._generic = True
            self._stub = GenericStub(channel)
            self._request = b'\0' * config.payload_config.bytebuf_params.req_size

        self._hist = hist
        self._response_callbacks = []

    def add_response_callback(self, callback):
        """Registers a callback invoked as callback(client, query_time)."""
        self._response_callbacks.append(callback)

    @abc.abstractmethod
    def send_request(self):
        """Non-blocking wrapper for a client's request operation."""
        raise NotImplementedError()

    def start(self):
        pass

    def stop(self):
        pass

    def _handle_response(self, client, query_time):
        self._hist.add(query_time * 1e9)  # Report times in nanoseconds
        for callback in self._response_callbacks:
            callback(client, query_time)


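# Illustrative driver lifecycle (a sketch; the loop count num_requests is a
# hypothetical driver-side detail): a QPS driver constructs one of the
# concrete clients below, calls start() once, issues non-blocking
# send_request() calls for every RPC it wants in flight, and finally calls
# stop() to drain outstanding work. Roughly:
#
#     client = UnarySyncBenchmarkClient(server, config, hist)
#     client.start()
#     for _ in range(num_requests):
#         client.send_request()
#     client.stop()
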
class UnarySyncBenchmarkClient(BenchmarkClient):

    def __init__(self, server, config, hist):
        super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=config.outstanding_rpcs_per_channel)

    def send_request(self):
        # Send requests in separate threads to support multiple outstanding RPCs
        # (See src/proto/grpc/testing/control.proto)
        self._pool.submit(self._dispatch_request)

    def stop(self):
        self._pool.shutdown(wait=True)
        self._stub = None

    def _dispatch_request(self):
        start_time = time.time()
        self._stub.UnaryCall(self._request, _TIMEOUT)
        end_time = time.time()
        self._handle_response(self, end_time - start_time)


class UnaryAsyncBenchmarkClient(BenchmarkClient):

    def send_request(self):
        # Use the Future callback API to support multiple outstanding RPCs
        start_time = time.time()
        response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
        response_future.add_done_callback(
            lambda resp: self._response_received(start_time, resp))

    def _response_received(self, start_time, resp):
        resp.result()
        end_time = time.time()
        self._handle_response(self, end_time - start_time)

    def stop(self):
        self._stub = None


class _SyncStream(object):

    def __init__(self, stub, generic, request, handle_response):
        self._stub = stub
        self._generic = generic
        self._request = request
        self._handle_response = handle_response
        self._is_streaming = False
        # Requests are fed to the request iterator through this queue; send
        # times are recorded separately to compute per-response latency.
        self._request_queue = queue.Queue()
        self._send_time_queue = queue.Queue()

    def send_request(self):
        self._send_time_queue.put(time.time())
        self._request_queue.put(self._request)

    def start(self):
        self._is_streaming = True
        response_stream = self._stub.StreamingCall(self._request_generator(),
                                                   _TIMEOUT)
        for _ in response_stream:
            self._handle_response(
                self,
                time.time() - self._send_time_queue.get_nowait())

    def stop(self):
        self._is_streaming = False

    def _request_generator(self):
        while self._is_streaming:
            try:
                # Wait at most a second so the generator can notice stop().
                request = self._request_queue.get(block=True, timeout=1.0)
                yield request
            except queue.Empty:
                pass


class StreamingSyncBenchmarkClient(BenchmarkClient):

    def __init__(self, server, config, hist):
        super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=config.outstanding_rpcs_per_channel)
        self._streams = [
            _SyncStream(self._stub, self._generic, self._request,
                        self._handle_response)
            for _ in range(config.outstanding_rpcs_per_channel)
        ]
        self._curr_stream = 0

    def send_request(self):
        # Use round-robin scheduling to determine which stream to send on
        self._streams[self._curr_stream].send_request()
        self._curr_stream = (self._curr_stream + 1) % len(self._streams)

    def start(self):
        for stream in self._streams:
            self._pool.submit(stream.start)

    def stop(self):
        for stream in self._streams:
            stream.stop()
        self._pool.shutdown(wait=True)
        self._stub = None
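

if __name__ == '__main__':
    # Minimal smoke-test sketch. Assumes control_pb2 (generated from
    # src/proto/grpc/testing/control.proto) is importable and that a
    # BenchmarkService server is listening at the hypothetical address below.
    from src.proto.grpc.testing import control_pb2

    class _PrintingHistogram(object):
        """Stand-in histogram exposing only the add() hook used above."""

        def add(self, value):
            # Values arrive in nanoseconds (see _handle_response above).
            print('query time: %.3f ms' % (value / 1e6))

    config = control_pb2.ClientConfig(outstanding_rpcs_per_channel=1)
    config.payload_config.simple_params.req_size = 16
    config.payload_config.simple_params.resp_size = 16

    hist = _PrintingHistogram()
    client = UnarySyncBenchmarkClient('localhost:50051', config, hist)
    client.start()
    client.send_request()
    client.stop()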