• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Common utilities for tests of the Cython layer of gRPC Python."""
15
16import collections
17import threading
18
19from grpc._cython import cygrpc
20
# How many times execute_many_times invokes its behavior.
RPC_COUNT = 4000

# A flags value with no bits set.
EMPTY_FLAGS = 0

# Metadata sent by the client when invoking an RPC. gRPC treats keys ending in
# "-bin" as binary-valued; that entry carries a 6000-byte payload.
# NOTE(review): the text value repeats the key instead of following the
# "...-value" pattern used by the server-side constants; kept as-is.
INVOCATION_METADATA = (
    ("client-md-key", "client-md-key"),
    ("client-md-key-bin", bytes((0x00, 0x01)) * 3000),
)

# Initial metadata sent by the server (text entry plus a 6000-byte binary one).
INITIAL_METADATA = (
    ("server-initial-md-key", "server-initial-md-value"),
    ("server-initial-md-key-bin", bytes((0x00, 0x02)) * 3000),
)

# Trailing metadata sent by the server (text entry plus a 6000-byte binary one).
TRAILING_METADATA = (
    ("server-trailing-md-key", "server-trailing-md-value"),
    ("server-trailing-md-key-bin", bytes((0x00, 0x03)) * 3000),
)
39
40
class QueueDriver(object):
    """Drains a completion queue on a background thread, grouping events by tag.

    Callers announce the tags they expect via ``add_due`` and later block in
    ``event_with_tag`` to claim the matching event. The poller thread and
    ``event_with_tag`` synchronize on the shared ``condition``.
    """

    def __init__(self, condition, completion_queue):
        self._condition = condition
        self._completion_queue = completion_queue
        # tag -> number of events still expected for that tag.
        self._due = collections.defaultdict(int)
        # tag -> FIFO of events that arrived but have not been claimed yet.
        # A deque makes removal from the front O(1) (list.pop(0) is O(n)).
        self._events = collections.defaultdict(collections.deque)

    def add_due(self, tags):
        """Record that one more event is expected for each tag in ``tags``.

        If nothing was outstanding, a new poller thread is started. That thread
        keeps polling until every expected event has arrived, and then exits.

        The caller must hold ``condition`` while calling this. Otherwise the
        "nothing outstanding" check here can interleave with the poller
        deciding to exit, which can leave newly added tags with no poller.
        (RpcTest.setUp calls it inside ``with self.server_condition``.)

        Args:
          tags: An iterable of tags; each occurrence adds one expected event.
        """
        if not self._due:

            def in_thread():
                while True:
                    # Poll outside the lock so event_with_tag waiters are not
                    # blocked while the queue is waiting for an event.
                    event = self._completion_queue.poll()
                    with self._condition:
                        self._events[event.tag].append(event)
                        self._due[event.tag] -= 1
                        self._condition.notify_all()
                        if self._due[event.tag] <= 0:
                            self._due.pop(event.tag)
                            if not self._due:
                                # Nothing left to wait for; a later add_due
                                # starts a fresh poller.
                                return

            thread = threading.Thread(target=in_thread)
            thread.start()
        for tag in tags:
            self._due[tag] += 1

    def event_with_tag(self, tag):
        """Block until an event for ``tag`` is available, then remove and return it.

        Events sharing a tag are returned in the order they were polled. This
        waits without a timeout, so it blocks forever if no such event arrives.
        """
        with self._condition:
            self._condition.wait_for(lambda: self._events[tag])
            return self._events[tag].popleft()
75
76
def execute_many_times(behavior):
    """Call ``behavior`` (with no arguments) RPC_COUNT times, sequentially.

    Returns:
      A tuple holding each call's return value, in call order. Any exception
      raised by ``behavior`` propagates immediately.
    """
    outcomes = [behavior() for _ in range(RPC_COUNT)]
    return tuple(outcomes)
79
80
class OperationResult(
    collections.namedtuple(
        "OperationResult",
        "start_batch_result completion_type success",
    )
):
    """Immutable record describing how a batch of call operations turned out.

    Fields:
      start_batch_result: Result of starting the batch (a ``cygrpc.CallError``
        value such as ``CallError.ok`` in this module's usage).
      completion_type: Kind of completion observed (e.g.
        ``cygrpc.CompletionType.operation_complete``).
      success: Whether the completion reported success.
    """
92
93
# The OperationResult of a batch that started with CallError.ok and whose
# operation_complete event reported success.
SUCCESSFUL_OPERATION_RESULT = OperationResult(
    start_batch_result=cygrpc.CallError.ok,
    completion_type=cygrpc.CompletionType.operation_complete,
    success=True,
)
97
98
class RpcTest(object):
    """Fixture that connects a cygrpc server and a client channel for tests.

    The setUp/tearDown names suggest this is mixed into a unittest.TestCase
    subclass (it derives only from object) -- confirm against its users.

    Attributes created by setUp:
      server, server_completion_queue, server_condition, server_driver:
        the server, its registered completion queue, and a QueueDriver
        draining that queue.
      channel: client channel connected to the server's port on localhost.
      client_condition, client_completion_queue, client_driver:
        client-side completion queue and the QueueDriver draining it.
    """

    def setUp(self):
        # Server: register a completion queue, bind an OS-chosen port on all
        # interfaces, and start serving.
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([(b"grpc.so_reuseport", 0)], False)
        self.server.register_completion_queue(self.server_completion_queue)
        port = self.server.add_http2_port(b"[::]:0")
        self.server.start()

        target = "localhost:{}".format(port).encode()
        self.channel = cygrpc.Channel(target, [], None)

        shutdown_tag = "server_shutdown_tag"
        self._server_shutdown_tag = shutdown_tag
        self.server_condition = threading.Condition()
        self.server_driver = QueueDriver(
            self.server_condition, self.server_completion_queue
        )
        # Expect the shutdown tag up front (with the condition held, as
        # QueueDriver.add_due requires) so the server queue is drained until
        # tearDown's shutdown event arrives.
        with self.server_condition:
            self.server_driver.add_due({shutdown_tag})

        self.client_condition = threading.Condition()
        self.client_completion_queue = cygrpc.CompletionQueue()
        self.client_driver = QueueDriver(
            self.client_condition, self.client_completion_queue
        )

    def tearDown(self):
        # Ask the server to shut down, reporting completion on its queue under
        # the tag registered in setUp, then cancel any calls still in flight.
        # NOTE(review): this does not wait for the shutdown event to arrive.
        server = self.server
        server.shutdown(self.server_completion_queue, self._server_shutdown_tag)
        server.cancel_all_calls()
133