# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for tests of the Cython layer of gRPC Python."""

import collections
import threading

from grpc._cython import cygrpc

RPC_COUNT = 4000

EMPTY_FLAGS = 0

# Each metadata tuple pairs a plain string entry with a "-bin" entry whose
# value is a 6000-byte binary payload.
INVOCATION_METADATA = (
    ("client-md-key", "client-md-key"),
    ("client-md-key-bin", b"\x00\x01" * 3000),
)

INITIAL_METADATA = (
    ("server-initial-md-key", "server-initial-md-value"),
    ("server-initial-md-key-bin", b"\x00\x02" * 3000),
)

TRAILING_METADATA = (
    ("server-trailing-md-key", "server-trailing-md-value"),
    ("server-trailing-md-key-bin", b"\x00\x03" * 3000),
)


class QueueDriver(object):
    """Drains a completion queue on a background thread and hands out events.

    Callers register the tags they expect with ``add_due`` and later retrieve
    each tag's events, oldest first, with ``event_with_tag``. The driver's
    shared state is read and written while holding the ``condition`` given to
    the constructor.
    """

    def __init__(self, condition, completion_queue):
        """Creates a driver with no outstanding tags and no polling thread.

        Args:
          condition: The ``threading.Condition`` guarding this driver's state;
            it is notified every time a new event is recorded.
          completion_queue: The completion queue to poll; each polled event is
            filed under its ``tag`` attribute.
        """
        self._condition = condition
        self._completion_queue = completion_queue
        # Number of events still expected per tag. A tag is removed once its
        # count reaches zero; the polling thread exits when none remain.
        self._due = collections.defaultdict(int)
        # Received-but-unconsumed events per tag, oldest first. A deque makes
        # taking the oldest event O(1) (list.pop(0) is O(n)).
        self._events = collections.defaultdict(collections.deque)

    def add_due(self, tags):
        """Registers one more expected event for each element of ``tags``.

        If no events are currently outstanding, a new polling thread is started
        first. The polling thread mutates ``_due`` only while holding the
        condition, and this method does not acquire it, so callers are expected
        to hold the condition themselves (as ``RpcTest.setUp`` does).

        Args:
          tags: An iterable of tags; each occurrence adds one expected event.
        """
        if not self._due:

            def in_thread():
                while True:
                    # Poll without holding the lock so consumers can still
                    # acquire the condition while this thread waits (poll() is
                    # assumed to block until an event is available).
                    event = self._completion_queue.poll()
                    with self._condition:
                        self._events[event.tag].append(event)
                        self._due[event.tag] -= 1
                        self._condition.notify_all()
                        # An event for an unregistered tag drives its count
                        # negative; it is dropped from _due here too, while the
                        # event itself stays recorded above.
                        if self._due[event.tag] <= 0:
                            self._due.pop(event.tag)
                            # Nothing left to wait for: stop, and let a later
                            # add_due start a fresh thread.
                            if not self._due:
                                return

            thread = threading.Thread(target=in_thread)
            thread.start()
        for tag in tags:
            self._due[tag] += 1

    def event_with_tag(self, tag):
        """Blocks until an event for ``tag`` has been recorded, then returns it.

        Events for the same tag are returned in the order the polling thread
        recorded them, and each event is returned at most once.

        Args:
          tag: The tag whose next event should be returned.

        Returns:
          The oldest not-yet-returned event for ``tag``.
        """
        with self._condition:
            self._condition.wait_for(lambda: self._events[tag])
            return self._events[tag].popleft()


def execute_many_times(behavior):
    """Calls ``behavior`` RPC_COUNT times and returns the results as a tuple."""
    return tuple(behavior() for _ in range(RPC_COUNT))


class OperationResult(
    collections.namedtuple(
        "OperationResult",
        (
            "start_batch_result",
            "completion_type",
            "success",
        ),
    )
):
    """A (start_batch_result, completion_type, success) outcome triple."""


SUCCESSFUL_OPERATION_RESULT = OperationResult(
    cygrpc.CallError.ok, cygrpc.CompletionType.operation_complete, True
)


class RpcTest(object):
    """Sets up a local cygrpc server and a client channel around each test.

    This class defines ``setUp``/``tearDown`` but does not subclass
    ``unittest.TestCase``; it is meant to be combined with a test case class.
    """

    def setUp(self):
        """Starts a server, a channel to it, and completion-queue drivers."""
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([(b"grpc.so_reuseport", 0)], False)
        self.server.register_completion_queue(self.server_completion_queue)
        # Port 0 lets the server choose; the chosen port is returned and used
        # for the client channel target below.
        port = self.server.add_http2_port(b"[::]:0")
        self.server.start()
        self.channel = cygrpc.Channel(
            "localhost:{}".format(port).encode(), [], None
        )

        self._server_shutdown_tag = "server_shutdown_tag"
        self.server_condition = threading.Condition()
        self.server_driver = QueueDriver(
            self.server_condition, self.server_completion_queue
        )
        # Registering the shutdown tag up front keeps the server driver's
        # polling thread alive until the shutdown requested in tearDown
        # produces its event. add_due is called with the condition held.
        with self.server_condition:
            self.server_driver.add_due(
                {
                    self._server_shutdown_tag,
                }
            )

        self.client_condition = threading.Condition()
        self.client_completion_queue = cygrpc.CompletionQueue()
        self.client_driver = QueueDriver(
            self.client_condition, self.client_completion_queue
        )

    def tearDown(self):
        """Requests server shutdown and cancels any calls still in flight."""
        # Shutdown completion is reported on the server completion queue under
        # the tag registered with the server driver in setUp.
        self.server.shutdown(
            self.server_completion_queue, self._server_shutdown_tag
        )
        self.server.cancel_all_calls()