• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Common code used throughout tests of gRPC."""
15
16import collections
17import threading
18
19from concurrent import futures
20import grpc
21import six
22
# Sample metadata shared by gRPC tests. Each value is a tuple of (key, value)
# string pairs; the three groups use disjoint single-digit keys so that test
# assertions can tell which group a transmitted pair came from.
INVOCATION_INITIAL_METADATA = (('0', 'abc'), ('1', 'def'), ('2', 'ghi'))
SERVICE_INITIAL_METADATA = (('3', 'jkl'), ('4', 'mno'), ('5', 'pqr'))
SERVICE_TERMINAL_METADATA = (('6', 'stu'), ('7', 'vwx'), ('8', 'yza'))

# Sample status-details string used by tests.
DETAILS = 'test details'
39
40
def metadata_transmitted(original_metadata, transmitted_metadata):
    """Judges whether or not metadata was acceptably transmitted.

    gRPC is allowed to insert key-value pairs into the metadata values given
    by applications and to reorder key-value pairs with different keys, but it
    is not allowed to alter existing key-value pairs or to reorder key-value
    pairs with the same key.

    Args:
      original_metadata: A metadata value used in a test of gRPC. An iterable
        over iterables of length 2.
      transmitted_metadata: A metadata value corresponding to original_metadata
        after having been transmitted via gRPC. An iterable over iterables of
        length 2.

    Returns:
      A boolean indicating whether transmitted_metadata accurately reflects
      original_metadata after having been transmitted via gRPC.
    """
    original = collections.defaultdict(list)
    for key, value in original_metadata:
        original[key].append(value)
    transmitted = collections.defaultdict(list)
    for key, value in transmitted_metadata:
        transmitted[key].append(value)

    for key, values in original.items():
        # Per key, the original values must occur in the same relative order
        # among the transmitted values (extra values may be interleaved).
        # `in` on a shared iterator consumes it up to and including the
        # match, so each successive value can only be found further ahead.
        # `.get` avoids inserting missing keys into the defaultdict.
        remaining = iter(transmitted.get(key, ()))
        if not all(value in remaining for value in values):
            return False
    return True
80
81
def test_secure_channel(target, channel_credentials, server_host_override):
    """Creates a secure Channel to a remote server.

    Args:
      target: The server address to connect to, as accepted by
        grpc.secure_channel.
      channel_credentials: The grpc.ChannelCredentials with which to connect.
      server_host_override: The target name used for SSL host name checking;
        passed through as the 'grpc.ssl_target_name_override' channel option.

    Returns:
      A grpc.Channel to the remote server through which RPCs may be conducted.
    """
    options = (('grpc.ssl_target_name_override', server_host_override),)
    return grpc.secure_channel(target, channel_credentials, options=options)
101
102
def test_server(max_workers=10, reuse_port=False):
    """Creates a gRPC server for use in tests.

    By default the server has SO_REUSEPORT disabled to prevent cross-talk
    between tests; pass reuse_port=True to enable it.

    Args:
      max_workers: Maximum number of threads in the ThreadPoolExecutor that
        the server uses to handle RPCs.
      reuse_port: Whether to enable SO_REUSEPORT; converted to 0/1 for the
        'grpc.so_reuseport' server option.

    Returns:
      A grpc.Server. This function neither adds ports nor starts it; the
      caller is responsible for both.
    """
    executor = futures.ThreadPoolExecutor(max_workers=max_workers)
    return grpc.server(executor,
                       options=(('grpc.so_reuseport', int(reuse_port)),))
110
111
class WaitGroup(object):
    """A thread-safe counter that lets threads block until it drops to zero.

    Callers raise the counter with the constructor or add(), lower it by one
    with done() (or by any amount with a negative add()), and block in wait()
    until the counter is no longer positive.

    Attributes:
      count: The current counter value; read and written only while holding
        cv.
      cv: Condition variable guarding count and used to wake waiters.
    """

    def __init__(self, n=0):
        self.count = n
        self.cv = threading.Condition()

    def add(self, n):
        """Adds n (which may be negative) to the counter."""
        with self.cv:
            self.count += n
            # A negative delta can bring the counter to zero just like done();
            # waiters must be woken or they would block forever.
            if self.count <= 0:
                self.cv.notify_all()

    def done(self):
        """Decrements the counter by one, waking waiters once it is <= 0."""
        with self.cv:
            self.count -= 1
            # Waiters re-check the predicate in wait(), so an extra
            # notification here is harmless.
            if self.count <= 0:
                self.cv.notify_all()

    def wait(self):
        """Blocks until the counter is zero or negative."""
        # The `with` block releases the lock even if wait() raises (e.g.
        # KeyboardInterrupt); the original manual acquire/release did not.
        with self.cv:
            while self.count > 0:
                self.cv.wait()
135