• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""
15
16import abc
17import threading
18import time
19
20from concurrent import futures
21from six.moves import queue
22
23import grpc
24from src.proto.grpc.testing import messages_pb2
25from src.proto.grpc.testing import benchmark_service_pb2_grpc
26from tests.unit import resources
27from tests.unit import test_common
28
# Per-RPC deadline in seconds (24 hours) — effectively "no deadline" so that
# long benchmark runs and streaming calls are never cut off by the client.
_TIMEOUT = 60 * 60 * 24
30
31
class GenericStub(object):
    """Stub exposing the benchmark RPCs over a serializer-less channel.

    Mirrors the handle names of the generated BenchmarkService stub, but
    supplies no (de)serializers, so request and response payloads pass
    through as raw bytes — which is what the generic (bytebuf) payload
    configuration requires.
    """

    def __init__(self, channel):
        # The two multicallables are independent; registration order on the
        # channel does not matter.
        self.StreamingCall = channel.stream_stream(
            '/grpc.testing.BenchmarkService/StreamingCall')
        self.UnaryCall = channel.unary_unary(
            '/grpc.testing.BenchmarkService/UnaryCall')
39
40
class BenchmarkClient:
    """Base class for benchmark clients with a non-blocking send_request().

    Subclasses implement send_request(); each completed query must be
    reported through _handle_response(), which records latency in the
    histogram and notifies registered callbacks.
    """

    # Python 2 metaclass hook; has no effect when run under Python 3.
    __metaclass__ = abc.ABCMeta

    def __init__(self, server, config, hist):
        """Connects to *server* and prepares the stub/request per *config*.

        Args:
          server: target string ("host:port") to connect to.
          config: client configuration proto (security + payload settings).
          hist: histogram that receives per-query latencies in nanoseconds.
        """
        if config.HasField('security_params'):
            credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates())
            channel = test_common.test_secure_channel(
                server, credentials,
                config.security_params.server_host_override)
        else:
            channel = grpc.insecure_channel(server)

        # Block until the channel is connected so that connection setup is
        # not counted against the benchmark measurements.
        grpc.channel_ready_future(channel).result()

        if config.payload_config.WhichOneof('payload') == 'simple_params':
            simple_params = config.payload_config.simple_params
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                channel)
            self._request = messages_pb2.SimpleRequest(
                payload=messages_pb2.Payload(
                    body=b'\0' * simple_params.req_size),
                response_size=simple_params.resp_size)
        else:
            # Generic (bytebuf) payloads bypass protobuf entirely.
            self._generic = True
            self._stub = GenericStub(channel)
            self._request = b'\0' * config.payload_config.bytebuf_params.req_size

        self._hist = hist
        self._response_callbacks = []

    def add_response_callback(self, callback):
        """Registers *callback*, invoked as callback(client, query_time)."""
        self._response_callbacks.append(callback)

    @abc.abstractmethod
    def send_request(self):
        """Non-blocking wrapper for a client's request operation."""
        raise NotImplementedError()

    def start(self):
        """Startup hook for subclasses; default is a no-op."""
        pass

    def stop(self):
        """Shutdown hook for subclasses; default is a no-op."""
        pass

    def _handle_response(self, client, query_time):
        # The histogram expects nanoseconds; query_time arrives in seconds.
        self._hist.add(query_time * 1e9)
        for callback in self._response_callbacks:
            callback(client, query_time)
97
98
class UnarySyncBenchmarkClient(BenchmarkClient):
    """Unary benchmark client that issues blocking RPCs from a thread pool."""

    def __init__(self, server, config, hist):
        super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
        # One worker thread per outstanding RPC lets that many blocking calls
        # be in flight simultaneously.
        # (See src/proto/grpc/testing/control.proto)
        self._pool = futures.ThreadPoolExecutor(
            max_workers=config.outstanding_rpcs_per_channel)

    def send_request(self):
        # Hand the blocking call to a worker so this method stays
        # non-blocking, as the BenchmarkClient contract requires.
        self._pool.submit(self._dispatch_request)

    def stop(self):
        # Wait for every in-flight RPC to finish before dropping the stub.
        self._pool.shutdown(wait=True)
        self._stub = None

    def _dispatch_request(self):
        # Runs on a pool thread; times one blocking unary call end-to-end.
        begin = time.time()
        self._stub.UnaryCall(self._request, _TIMEOUT)
        self._handle_response(self, time.time() - begin)
120
121
class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Unary benchmark client driven by RPC-future completion callbacks."""

    def send_request(self):
        # The Future callback API supports many outstanding RPCs without
        # dedicating a thread to each of them.
        started = time.time()
        call_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
        call_future.add_done_callback(
            lambda resp: self._response_received(started, resp))

    def _response_received(self, start_time, resp):
        # result() re-raises if the RPC terminated with an error.
        resp.result()
        self._handle_response(self, time.time() - start_time)

    def stop(self):
        self._stub = None
138
139
140class _SyncStream(object):
141
142    def __init__(self, stub, generic, request, handle_response):
143        self._stub = stub
144        self._generic = generic
145        self._request = request
146        self._handle_response = handle_response
147        self._is_streaming = False
148        self._request_queue = queue.Queue()
149        self._send_time_queue = queue.Queue()
150
151    def send_request(self):
152        self._send_time_queue.put(time.time())
153        self._request_queue.put(self._request)
154
155    def start(self):
156        self._is_streaming = True
157        response_stream = self._stub.StreamingCall(self._request_generator(),
158                                                   _TIMEOUT)
159        for _ in response_stream:
160            self._handle_response(
161                self,
162                time.time() - self._send_time_queue.get_nowait())
163
164    def stop(self):
165        self._is_streaming = False
166
167    def _request_generator(self):
168        while self._is_streaming:
169            try:
170                request = self._request_queue.get(block=True, timeout=1.0)
171                yield request
172            except queue.Empty:
173                pass
174
175
class StreamingSyncBenchmarkClient(BenchmarkClient):
    """Streaming benchmark client multiplexing over several sync streams."""

    def __init__(self, server, config, hist):
        super(StreamingSyncBenchmarkClient, self).__init__(server, config,
                                                           hist)
        stream_count = config.outstanding_rpcs_per_channel
        # One stream per outstanding RPC; each stream occupies one pool
        # thread for the lifetime of its call.
        self._pool = futures.ThreadPoolExecutor(max_workers=stream_count)
        self._streams = [
            _SyncStream(self._stub, self._generic, self._request,
                        self._handle_response) for _ in range(stream_count)
        ]
        self._curr_stream = 0

    def send_request(self):
        # Round-robin over the streams so load is spread evenly.
        index = self._curr_stream
        self._curr_stream = (index + 1) % len(self._streams)
        self._streams[index].send_request()

    def start(self):
        for stream in self._streams:
            self._pool.submit(stream.start)

    def stop(self):
        # Signal every stream to wind down, then wait for their threads.
        for stream in self._streams:
            stream.stop()
        self._pool.shutdown(wait=True)
        self._stub = None
203