• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Common utilities for tests of the Cython layer of gRPC Python."""
15
16import collections
17import threading
18
19from grpc._cython import cygrpc
20
# Number of iterations performed by execute_many_times().
RPC_COUNT = 4000

# NOTE(review): not referenced in the visible part of this file; presumably
# the "no flags" value handed to cygrpc operations -- confirm against callers.
EMPTY_FLAGS = 0

# Repeat count applied to each two-byte pattern below, which makes every
# '-bin' metadata value 6000 bytes long.
_BINARY_VALUE_REPEATS = 3000

# NOTE(review): the first value repeats the key text ('client-md-key'),
# unlike the '...-value' strings used for the server metadata -- confirm
# this is intentional before changing it.
INVOCATION_METADATA = (
    ('client-md-key', 'client-md-key'),
    ('client-md-key-bin', b'\x00\x01' * _BINARY_VALUE_REPEATS),
)

INITIAL_METADATA = (
    ('server-initial-md-key', 'server-initial-md-value'),
    ('server-initial-md-key-bin', b'\x00\x02' * _BINARY_VALUE_REPEATS),
)

TRAILING_METADATA = (
    ('server-trailing-md-key', 'server-trailing-md-value'),
    ('server-trailing-md-key-bin', b'\x00\x03' * _BINARY_VALUE_REPEATS),
)
39
40
class QueueDriver(object):
    """Polls a completion queue on a helper thread and files events by tag.

    Events returned by the queue's poll() are stored under their `tag`
    attribute; event_with_tag() blocks on the shared condition until an
    event for the requested tag has been stored.
    """

    def __init__(self, condition, completion_queue):
        """Store the collaborators and start with nothing due.

        Args:
          condition: Condition-like object guarding this driver's state; it
            is used as a context manager and for wait()/notify_all().
          completion_queue: Object whose poll() returns events carrying a
            `tag` attribute.
        """
        self._condition = condition
        self._completion_queue = completion_queue
        # Outstanding event count per tag; emptiness means "no poller needed".
        self._due = collections.defaultdict(int)
        # Events received so far, per tag, oldest first.
        self._events = collections.defaultdict(list)

    def _poll_until_nothing_due(self):
        """Thread body: poll and record events until no tag remains due."""
        keep_polling = True
        while keep_polling:
            # Polling happens outside the condition so waiters are not blocked.
            event = self._completion_queue.poll()
            with self._condition:
                self._events[event.tag].append(event)
                self._due[event.tag] -= 1
                self._condition.notify_all()
                if self._due[event.tag] > 0:
                    continue
                # This tag is fully satisfied; stop once every tag is.
                self._due.pop(event.tag)
                keep_polling = bool(self._due)

    def add_due(self, tags):
        """Register one expected event per entry of `tags`.

        When nothing was due beforehand, a polling thread is started first.

        NOTE(review): this method takes no lock itself; the caller visible in
        this file invokes it while holding the condition -- confirm that all
        callers are expected to do the same.

        Args:
          tags: Iterable of tags; each occurrence adds one expected event.
        """
        poller_needed = not self._due
        if poller_needed:
            threading.Thread(target=self._poll_until_nothing_due).start()
        for tag in tags:
            self._due[tag] += 1

    def event_with_tag(self, tag):
        """Block until an event for `tag` is available, then return it.

        Events for the same tag are returned in the order they were recorded.
        """
        with self._condition:
            while not self._events[tag]:
                self._condition.wait()
            return self._events[tag].pop(0)
76
77
def execute_many_times(behavior):
    """Call `behavior` RPC_COUNT times and return its results as a tuple.

    Args:
      behavior: Zero-argument callable.

    Returns:
      A tuple holding the return value of each call, in call order.
    """
    outcomes = []
    for _ in range(RPC_COUNT):
        outcomes.append(behavior())
    return tuple(outcomes)
80
81
class OperationResult(
        collections.namedtuple(
            'OperationResult',
            ['start_batch_result', 'completion_type', 'success'])):
    """Immutable (start_batch_result, completion_type, success) triple."""
89
90
# The OperationResult built from cygrpc's "ok" call error, the
# operation_complete completion type and success=True.
SUCCESSFUL_OPERATION_RESULT = OperationResult(
    start_batch_result=cygrpc.CallError.ok,
    completion_type=cygrpc.CompletionType.operation_complete,
    success=True,
)
93
94
class RpcTest(object):
    """Fixture that wires a cygrpc server, a channel and two queue drivers.

    NOTE(review): the setUp/tearDown names suggest use as a mix-in alongside
    unittest.TestCase -- confirm against the classes that use it.
    """

    def setUp(self):
        """Start a server, open a channel to it and create both drivers."""
        # Server side: completion queue, server bound to b'[::]:0', started.
        # NOTE(review): the meaning of the trailing `False` argument is not
        # visible from this file -- see cygrpc.Server.
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([(b'grpc.so_reuseport', 0)], False)
        self.server.register_completion_queue(self.server_completion_queue)
        bound_port = self.server.add_http2_port(b'[::]:0')
        self.server.start()

        # The channel targets whichever port add_http2_port() reported.
        target = 'localhost:{}'.format(bound_port).encode()
        self.channel = cygrpc.Channel(target, [], None)

        # Register the shutdown tag as due up front, under the condition,
        # which also starts the server driver's polling thread.
        self._server_shutdown_tag = 'server_shutdown_tag'
        self.server_condition = threading.Condition()
        self.server_driver = QueueDriver(self.server_condition,
                                         self.server_completion_queue)
        with self.server_condition:
            self.server_driver.add_due({self._server_shutdown_tag})

        # Client side: nothing is due yet, so no polling thread is started.
        self.client_condition = threading.Condition()
        self.client_completion_queue = cygrpc.CompletionQueue()
        self.client_driver = QueueDriver(self.client_condition,
                                         self.client_completion_queue)

    def tearDown(self):
        """Request server shutdown under the shutdown tag, then cancel calls.

        NOTE(review): the visible code neither waits for the shutdown event
        nor closes the channel -- confirm that is intended.
        """
        self.server.shutdown(self.server_completion_queue,
                             self._server_shutdown_tag)
        self.server.cancel_all_calls()
124