• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test making many calls and immediately cancelling most of them."""
15
16import threading
17import unittest
18
19from grpc._cython import cygrpc
20from grpc.framework.foundation import logging_pool
21
22from tests.unit._cython import test_utilities
23from tests.unit.framework.common import test_constants
24
# Every batch in this test uses no operation flags and no metadata entries.
_EMPTY_FLAGS = 0
_EMPTY_METADATA = tuple()

# Completion-queue tags identifying server-side events.
_SERVER_SHUTDOWN_TAG = "server_shutdown"
_REQUEST_CALL_TAG = "request_call"
_RECEIVE_CLOSE_ON_SERVER_TAG = "receive_close_on_server"
_RECEIVE_MESSAGE_TAG = "receive_message"
_SERVER_COMPLETE_CALL_TAG = "server_complete_call"

# The client waits for one call event in eight before cancelling every call;
# the remaining events are drained after the cancellation.
_SUCCESS_CALL_FRACTION = 0.125
_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS
37
38
class _State(object):
    """Mutable state shared by the test body, the serving thread and handlers.

    Every read and write of these fields in this module happens while holding
    ``condition``, which is also used to signal changes to them.
    """

    def __init__(self):
        # Guards the fields below; notify_all() announces changes to them.
        self.condition = threading.Condition()
        # Flipped to True by the test body to let parked handlers proceed.
        self.handlers_released = False
        # parked_handlers: handlers currently waiting to be released.
        # handled_rpcs: calls accepted so far by the serving thread.
        self.parked_handlers = self.handled_rpcs = 0
45
46
def _is_cancellation_event(event):
    """Return whether ``event`` is a cancelled receive-close-on-server batch.

    Args:
      event: An event polled from a call completion queue. Its ``tag`` is
        compared against the receive-close-on-server tag. Only for that tag is
        ``batch_operations[0]`` inspected; that batch is started with a single
        ReceiveCloseOnServerOperation.

    Returns:
      True when the tag matches and that operation reports ``cancelled()``;
      False otherwise.
    """
    # Compare tags by value. An identity check (`is`) only succeeds if the
    # exact same string object comes back from the completion queue, which
    # an equal-but-distinct string would not satisfy.
    return (
        event.tag == _RECEIVE_CLOSE_ON_SERVER_TAG
        and event.batch_operations[0].cancelled()
    )
52
53
class _Handler(object):
    """Serves one accepted RPC after the test releases the parked handlers.

    ``_serve`` submits instances to the server thread pool. Calling one parks
    the worker thread until ``state.handlers_released`` is set. It then starts
    the receive-close and receive-message batches. If the first event polled
    is a cancelled receive-close, it only drains the remaining event.
    Otherwise it sends metadata, a message and an OK status, then drains the
    two remaining events.
    """

    def __init__(self, state, completion_queue, rpc_event):
        self._state = state
        # Per-handler lock held while starting batches on this call.
        self._lock = threading.Lock()
        # Completion queue the call's batch events are polled from.
        self._completion_queue = completion_queue
        # Only the call is kept from the request-call event.
        self._call = rpc_event.call

    def __call__(self):
        self._park()
        self._start_receive_batches()
        first_event = self._completion_queue.poll()
        if _is_cancellation_event(first_event):
            # Cancelled: only the receive-message batch is still outstanding.
            outstanding = 1
        else:
            # Answer the call, then collect the other receive batch and the
            # completion of the response batch.
            self._send_response()
            outstanding = 2
        for _ in range(outstanding):
            self._completion_queue.poll()

    def _park(self):
        """Count this handler as parked and block until handlers are released."""
        state = self._state
        with state.condition:
            state.parked_handlers += 1
            if state.parked_handlers == test_constants.THREAD_CONCURRENCY:
                # THREAD_CONCURRENCY handlers are parked; wake the test body
                # that waits on this count.
                state.condition.notify_all()
            while not state.handlers_released:
                state.condition.wait()

    def _start_receive_batches(self):
        """Start the receive-close and receive-message batches on the call."""
        with self._lock:
            self._call.start_server_batch(
                (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
                _RECEIVE_CLOSE_ON_SERVER_TAG,
            )
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _RECEIVE_MESSAGE_TAG,
            )

    def _send_response(self):
        """Send initial metadata, one message and an OK status in one batch."""
        with self._lock:
            response_ops = (
                cygrpc.SendInitialMetadataOperation(
                    _EMPTY_METADATA, _EMPTY_FLAGS
                ),
                cygrpc.SendMessageOperation(b"\x79\x57", _EMPTY_FLAGS),
                cygrpc.SendStatusFromServerOperation(
                    _EMPTY_METADATA,
                    cygrpc.StatusCode.ok,
                    b"test details!",
                    _EMPTY_FLAGS,
                ),
            )
            self._call.start_server_batch(
                response_ops, _SERVER_COMPLETE_CALL_TAG
            )
100
101
def _serve(state, server, server_completion_queue, thread_pool):
    """Accept RPC_CONCURRENCY calls and dispatch each one to a _Handler.

    Each call gets its own new completion queue, passed to
    ``server.request_call``. The accepted call is wrapped in a ``_Handler``
    and submitted to ``thread_pool``. ``state.handled_rpcs`` counts accepted
    calls, and waiters on ``state.condition`` are notified once all have been
    accepted. Afterwards one more event is polled from
    ``server_completion_queue`` (presumably the server-shutdown notification
    requested by the test).
    """
    expected_calls = test_constants.RPC_CONCURRENCY
    for _ in range(expected_calls):
        call_queue = cygrpc.CompletionQueue()
        server.request_call(
            call_queue, server_completion_queue, _REQUEST_CALL_TAG
        )
        accepted_event = server_completion_queue.poll()
        thread_pool.submit(_Handler(state, call_queue, accepted_event))
        with state.condition:
            state.handled_rpcs += 1
            if state.handled_rpcs >= expected_calls:
                state.condition.notify_all()
    server_completion_queue.poll()
115
116
class _QueueDriver(object):
    """Drains a completion queue on a background thread.

    Polled events are recorded until every tag in ``due`` has been seen.
    Waiters in ``events`` are woken through the shared ``condition`` after
    each event and again when the driver thread stops.
    """

    def __init__(self, condition, completion_queue, due):
        # Shared condition guarding _events, _due and _returned.
        self._condition = condition
        self._completion_queue = completion_queue
        # Set of tags still expected; tags are removed in place as they arrive.
        self._due = due
        self._events = []
        # True once the driver thread has stopped polling.
        self._returned = False

    def start(self):
        """Start the background thread that polls the completion queue."""

        def in_thread():
            try:
                while True:
                    # Check before polling: polling with nothing left due
                    # would block this thread forever.
                    with self._condition:
                        if not self._due:
                            return
                    event = self._completion_queue.poll()
                    with self._condition:
                        self._events.append(event)
                        # Raises KeyError for a tag that was not due.
                        self._due.remove(event.tag)
                        self._condition.notify_all()
            finally:
                # Release waiters in events() however this thread exits.
                with self._condition:
                    self._returned = True
                    self._condition.notify_all()

        thread = threading.Thread(target=in_thread)
        thread.start()

    def events(self, at_least):
        """Block until ``at_least`` events are recorded or the driver stops.

        Args:
          at_least: Minimum number of events to wait for.

        Returns:
          A tuple snapshot of the events recorded so far. It can hold fewer
          than ``at_least`` events only when the driver thread has stopped,
          because no further events will be recorded after that.
        """
        with self._condition:
            while len(self._events) < at_least and not self._returned:
                self._condition.wait()
            return tuple(self._events)
145
146
class CancelManyCallsTest(unittest.TestCase):
    """Starts many RPCs, lets a fraction of them finish, cancels the rest."""

    def testCancelManyCalls(self):
        """Collect _SUCCESSFUL_CALLS call events, then cancel every call.

        Server handlers park until all RPC_CONCURRENCY calls have been
        accepted and THREAD_CONCURRENCY handlers are waiting.
        1. Release the handlers.
        2. Wait for _SUCCESSFUL_CALLS call events on the client.
        3. Cancel every client call.
        4. Drain the remaining _UNSUCCESSFUL_CALLS events.
        5. Close the channel and shut down the server and the thread pool.
        """
        server_thread_pool = logging_pool.pool(
            test_constants.THREAD_CONCURRENCY
        )

        server_completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([(b"grpc.so_reuseport", 0)], False)
        server.register_completion_queue(server_completion_queue)
        port = server.add_http2_port(b"[::]:0")
        server.start()
        channel = cygrpc.Channel(
            "localhost:{}".format(port).encode(), None, None
        )

        state = _State()
        server_thread = threading.Thread(
            target=_serve,
            args=(state, server, server_completion_queue, server_thread_pool),
        )
        server_thread.start()

        def new_client_operations():
            # Build fresh operation objects for each call rather than sharing
            # one tuple (receive operations presumably hold per-call results).
            return (
                cygrpc.SendInitialMetadataOperation(
                    _EMPTY_METADATA, _EMPTY_FLAGS
                ),
                cygrpc.SendMessageOperation(b"\x45\x56", _EMPTY_FLAGS),
                cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
            )

        client_calls = []
        for index in range(test_constants.RPC_CONCURRENCY):
            tag = "client_complete_call_{0:04d}_tag".format(index)
            client_calls.append(
                channel.integrated_call(
                    _EMPTY_FLAGS,
                    b"/twinkies",
                    None,
                    None,
                    _EMPTY_METADATA,
                    None,
                    ((new_client_operations(), tag),),
                )
            )

        # Gather the first _SUCCESSFUL_CALLS call events through SimpleFuture
        # (presumably asynchronously); result() below retrieves them.
        client_events_future = test_utilities.SimpleFuture(
            lambda: tuple(
                channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS)
            )
        )

        # Release the handlers only once every call has been accepted and
        # THREAD_CONCURRENCY handlers are parked.
        with state.condition:
            state.condition.wait_for(
                lambda: (
                    state.parked_handlers >= test_constants.THREAD_CONCURRENCY
                    and state.handled_rpcs >= test_constants.RPC_CONCURRENCY
                )
            )
            state.handlers_released = True
            state.condition.notify_all()

        client_events_future.result()
        for client_call in client_calls:
            client_call.cancel(cygrpc.StatusCode.cancelled, "Cancelled!")
        for _ in range(_UNSUCCESSFUL_CALLS):
            channel.next_call_event()

        channel.close(cygrpc.StatusCode.unknown, "Cancelled on channel close!")
        with state.condition:
            server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
        # Every handler was submitted before handled_rpcs reached
        # RPC_CONCURRENCY, so no more work will be queued. Let the pool's
        # workers exit once idle instead of leaking them. wait=False avoids
        # blocking teardown on handlers that are still draining their queues.
        server_thread_pool.shutdown(wait=False)
249
250
if __name__ == "__main__":
    # Run this module's tests directly, printing each test's name and result.
    unittest.main(module="__main__", verbosity=2)
253