# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for tests of the Cython layer of gRPC Python."""

import collections
import threading

from grpc._cython import cygrpc

# Default number of times execute_many_times() invokes a behavior.
RPC_COUNT = 4000

# A flags value of 0; presumably passed to operations that set no flags.
EMPTY_FLAGS = 0

# Metadata used by the tests. Each tuple pairs a text entry with a "-bin"
# entry whose value is 6000 bytes of binary data (a 2-byte pattern repeated
# 3000 times) -- presumably so that large binary metadata is exercised.
INVOCATION_METADATA = (
    ('client-md-key', 'client-md-key'),
    ('client-md-key-bin', b'\x00\x01' * 3000),
)

INITIAL_METADATA = (
    ('server-initial-md-key', 'server-initial-md-value'),
    ('server-initial-md-key-bin', b'\x00\x02' * 3000),
)

TRAILING_METADATA = (
    ('server-trailing-md-key', 'server-trailing-md-value'),
    ('server-trailing-md-key-bin', b'\x00\x03' * 3000),
)


class QueueDriver(object):
    """Drains a completion queue on a background thread, indexed by tag.

    Tags that are expected to complete are registered with add_due(). While
    at least one registered tag is outstanding, a background thread polls the
    completion queue, records each event under its tag and notifies waiters
    on the shared condition. The thread exits once nothing is outstanding.
    """

    def __init__(self, condition, completion_queue):
        """Creates a driver for `completion_queue`.

        Args:
          condition: A threading.Condition guarding this driver's state; it
            is notified every time an event is recorded.
          completion_queue: The completion queue to poll.
        """
        self._condition = condition
        self._completion_queue = completion_queue
        # tag -> number of events still expected for that tag.
        self._due = collections.defaultdict(int)
        # tag -> FIFO of events received but not yet consumed.
        self._events = collections.defaultdict(collections.deque)

    def _poll_until_idle(self):
        """Polls the queue and records events until no tag is outstanding."""
        while True:
            event = self._completion_queue.poll()
            with self._condition:
                self._events[event.tag].append(event)
                self._due[event.tag] -= 1
                self._condition.notify_all()
                if self._due[event.tag] <= 0:
                    self._due.pop(event.tag)
                    if not self._due:
                        return

    def add_due(self, tags):
        """Registers one expected event for each tag in `tags`.

        Starts the polling thread if none is currently running. Callers
        should hold the condition while calling this, because the polling
        thread mutates the same bookkeeping under that condition.

        Args:
          tags: An iterable of tags. Each occurrence of a tag (within and
            across calls) adds one more expected event for that tag.
        """
        # With the condition held, a poller is running exactly when some tag
        # is outstanding (the poller exits, under the lock, once none are).
        start_poller = not self._due
        # Register the tags before the thread starts so it can never observe
        # an empty registry and exit while these tags are still pending.
        for tag in tags:
            self._due[tag] += 1
        # Don't start a thread when nothing is expected: it would block in
        # poll() with no registered tag to wait for.
        if start_poller and self._due:
            threading.Thread(target=self._poll_until_idle).start()

    def event_with_tag(self, tag):
        """Blocks until an event with `tag` is available, then returns it.

        Events that share a tag are returned in the order they were polled.
        """
        with self._condition:
            while not self._events[tag]:
                self._condition.wait()
            return self._events[tag].popleft()


def execute_many_times(behavior, count=None):
    """Calls `behavior` repeatedly and returns its results as a tuple.

    Args:
      behavior: A zero-argument callable.
      count: How many times to call `behavior`. Defaults to the module-level
        RPC_COUNT (read at call time).

    Returns:
      A tuple of the `count` return values, in call order.
    """
    if count is None:
        count = RPC_COUNT
    return tuple(behavior() for _ in range(count))


class OperationResult(
        collections.namedtuple('OperationResult', (
            'start_batch_result',
            'completion_type',
            'success',
        ))):
    """Outcome of an operation.

    Presumably holds the result of starting the batch, the completion type
    of the resulting event, and its success flag -- confirm at call sites.
    """
    # Keep instances as lightweight as the underlying namedtuple.
    __slots__ = ()


# The result expected from an operation that started and completed OK.
SUCCESSFUL_OPERATION_RESULT = OperationResult(
    cygrpc.CallError.ok, cygrpc.CompletionType.operation_complete, True)


class RpcTest(object):
    """Fixture creating a local server, a channel to it and queue drivers.

    Defines setUp/tearDown without deriving from unittest.TestCase;
    presumably it is mixed into a TestCase subclass -- verify at call sites.
    """

    def setUp(self):
        """Starts a server on a local port and connects a channel to it."""
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([(b'grpc.so_reuseport', 0)], False)
        self.server.register_completion_queue(self.server_completion_queue)
        # Port 0 presumably requests an ephemeral port; the returned port is
        # used as the channel target below.
        port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [],
                                      None)

        self._server_shutdown_tag = 'server_shutdown_tag'
        self.server_condition = threading.Condition()
        self.server_driver = QueueDriver(self.server_condition,
                                         self.server_completion_queue)
        # Expect the shutdown event up front so the server driver keeps
        # polling until tearDown's shutdown completes.
        with self.server_condition:
            self.server_driver.add_due({
                self._server_shutdown_tag,
            })

        self.client_condition = threading.Condition()
        self.client_completion_queue = cygrpc.CompletionQueue()
        self.client_driver = QueueDriver(self.client_condition,
                                         self.client_completion_queue)

    def tearDown(self):
        """Requests server shutdown and cancels any in-flight calls.

        NOTE(review): this does not wait for the shutdown event; the server
        driver's polling thread consumes it because setUp registered it.
        """
        self.server.shutdown(self.server_completion_queue,
                             self._server_shutdown_tag)
        self.server.cancel_all_calls()