# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test making many calls and immediately cancelling most of them."""

import threading
import unittest

from grpc._cython import cygrpc
from grpc.framework.foundation import logging_pool
from tests.unit.framework.common import test_constants
from tests.unit._cython import test_utilities

_EMPTY_FLAGS = 0
_EMPTY_METADATA = ()

_SERVER_SHUTDOWN_TAG = 'server_shutdown'
_REQUEST_CALL_TAG = 'request_call'
_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server'
_RECEIVE_MESSAGE_TAG = 'receive_message'
_SERVER_COMPLETE_CALL_TAG = 'server_complete_call'

# The test waits for _SUCCESSFUL_CALLS client call events before cancelling
# every client call, then drains the remaining _UNSUCCESSFUL_CALLS events.
_SUCCESS_CALL_FRACTION = 1.0 / 8.0
_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS


class _State(object):
    """Bookkeeping shared by the server thread, the handlers and the test.

    Every field is read and written while holding ``condition``.
    """

    def __init__(self):
        self.condition = threading.Condition()
        # Set to True by the test to let parked handlers proceed.
        self.handlers_released = False
        # How many handlers have reached the parking point (never decremented).
        self.parked_handlers = 0
        # How many incoming RPCs _serve has submitted to the thread pool.
        self.handled_rpcs = 0


def _is_cancellation_event(event):
    """Return True if *event* reports that the server-side call was cancelled.

    The receive-close-on-server batch started by _Handler holds exactly one
    operation, whose ``cancelled()`` flag is consulted here.
    """
    # Compare tags by value: string identity is an interpreter implementation
    # detail and must not decide whether a call was cancelled.
    return (event.tag == _RECEIVE_CLOSE_ON_SERVER_TAG and
            event.batch_operations[0].cancelled())


class _Handler(object):
    """Serves one accepted RPC using that RPC's own completion queue."""

    def __init__(self, state, completion_queue, rpc_event):
        self._state = state
        self._lock = threading.Lock()
        self._completion_queue = completion_queue
        self._call = rpc_event.call

    def __call__(self):
        # Park until the test releases every handler at once. The handler that
        # brings the parked count to THREAD_CONCURRENCY wakes the test thread.
        with self._state.condition:
            self._state.parked_handlers += 1
            if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY:
                self._state.condition.notify_all()
            while not self._state.handlers_released:
                self._state.condition.wait()

        with self._lock:
            self._call.start_server_batch(
                (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
                _RECEIVE_CLOSE_ON_SERVER_TAG)
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _RECEIVE_MESSAGE_TAG)
        first_event = self._completion_queue.poll()
        if _is_cancellation_event(first_event):
            # Cancelled: only the receive-message batch is still outstanding.
            self._completion_queue.poll()
        else:
            with self._lock:
                operations = (
                    cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
                                                        _EMPTY_FLAGS),
                    cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS),
                    cygrpc.SendStatusFromServerOperation(
                        _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
                        _EMPTY_FLAGS),
                )
                self._call.start_server_batch(operations,
                                              _SERVER_COMPLETE_CALL_TAG)
            # Drain the receive batch not yet seen plus the send batch above.
            self._completion_queue.poll()
            self._completion_queue.poll()


def _serve(state, server, server_completion_queue, thread_pool):
    """Accept RPC_CONCURRENCY calls and hand each to a _Handler on the pool.

    Each accepted call gets its own completion queue. After all calls are
    accepted, one more event is polled from the server queue; the test later
    calls ``server.shutdown`` with _SERVER_SHUTDOWN_TAG on that queue.
    """
    for _ in range(test_constants.RPC_CONCURRENCY):
        call_completion_queue = cygrpc.CompletionQueue()
        server.request_call(call_completion_queue, server_completion_queue,
                            _REQUEST_CALL_TAG)
        rpc_event = server_completion_queue.poll()
        thread_pool.submit(_Handler(state, call_completion_queue, rpc_event))
        with state.condition:
            state.handled_rpcs += 1
            if test_constants.RPC_CONCURRENCY <= state.handled_rpcs:
                state.condition.notify_all()
    server_completion_queue.poll()


class _QueueDriver(object):
    """Polls a completion queue on a background thread and records events.

    NOTE(review): not referenced by the test in this module.
    """

    def __init__(self, condition, completion_queue, due):
        self._condition = condition
        self._completion_queue = completion_queue
        # Set of tags still expected; the polling thread exits once empty.
        self._due = due
        self._events = []
        self._returned = False

    def start(self):
        """Start a thread that polls until every tag in ``due`` is seen."""

        def in_thread():
            while True:
                event = self._completion_queue.poll()
                with self._condition:
                    self._events.append(event)
                    self._due.remove(event.tag)
                    self._condition.notify_all()
                    if not self._due:
                        self._returned = True
                        return

        thread = threading.Thread(target=in_thread)
        thread.start()

    def events(self, at_least):
        """Block until at least *at_least* events are recorded; return them."""
        with self._condition:
            while len(self._events) < at_least:
                self._condition.wait()
            return tuple(self._events)


class CancelManyCallsTest(unittest.TestCase):

    def testCancelManyCalls(self):
        """Start many calls, let a few complete, then cancel all of them."""
        server_thread_pool = logging_pool.pool(
            test_constants.THREAD_CONCURRENCY)

        server_completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        server.register_completion_queue(server_completion_queue)
        port = server.add_http2_port(b'[::]:0')
        server.start()
        channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None,
                                 None)

        state = _State()

        server_thread_args = (
            state,
            server,
            server_completion_queue,
            server_thread_pool,
        )
        server_thread = threading.Thread(target=_serve, args=server_thread_args)
        server_thread.start()

        client_condition = threading.Condition()
        client_due = set()

        # Issue every client call up front, each as a single full batch.
        with client_condition:
            client_calls = []
            for index in range(test_constants.RPC_CONCURRENCY):
                tag = 'client_complete_call_{0:04d}_tag'.format(index)
                client_call = channel.integrated_call(
                    _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA,
                    None, ((
                        (
                            cygrpc.SendInitialMetadataOperation(
                                _EMPTY_METADATA, _EMPTY_FLAGS),
                            cygrpc.SendMessageOperation(b'\x45\x56',
                                                        _EMPTY_FLAGS),
                            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                            cygrpc.ReceiveInitialMetadataOperation(
                                _EMPTY_FLAGS),
                            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                        ),
                        tag,
                    ),))
                client_due.add(tag)
                client_calls.append(client_call)

        # Collect the first _SUCCESSFUL_CALLS client events in the background.
        client_events_future = test_utilities.SimpleFuture(lambda: tuple(
            channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS)))

        # Release the handlers only once the pool is saturated with parked
        # handlers and every RPC has been accepted by the server thread.
        with state.condition:
            while True:
                if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
                    state.condition.wait()
                elif state.handled_rpcs < test_constants.RPC_CONCURRENCY:
                    state.condition.wait()
                else:
                    state.handlers_released = True
                    state.condition.notify_all()
                    break

        client_events_future.result()
        with client_condition:
            for client_call in client_calls:
                client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!')
        # Drain the events of the calls that did not complete before cancel.
        for _ in range(_UNSUCCESSFUL_CALLS):
            channel.next_call_event()

        channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!')
        with state.condition:
            server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)


if __name__ == '__main__':
    unittest.main(verbosity=2)