• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Common code used throughout tests of gRPC."""
15
16import ast
17import collections
18from concurrent import futures
19import os
20import threading
21
22import grpc
23
# Fixed metadata fixtures shared by gRPC tests; each is a tuple of
# (key, value) string pairs.
INVOCATION_INITIAL_METADATA = (("0", "abc"), ("1", "def"), ("2", "ghi"))
SERVICE_INITIAL_METADATA = (("3", "jkl"), ("4", "mno"), ("5", "pqr"))
SERVICE_TERMINAL_METADATA = (("6", "stu"), ("7", "vwx"), ("8", "yza"))
# Fixed details string shared by gRPC tests.
DETAILS = "test details"
40
41
def metadata_transmitted(original_metadata, transmitted_metadata):
    """Judges whether or not metadata was acceptably transmitted.

    gRPC is allowed to insert key-value pairs into the metadata values given by
    applications and to reorder key-value pairs with different keys but it is not
    allowed to alter existing key-value pairs or to reorder key-value pairs with
    the same key.

    Concretely: for every key in original_metadata, the values sent under that
    key must appear, in the same relative order, as a subsequence of the values
    found under that key in transmitted_metadata.

    Args:
      original_metadata: A metadata value used in a test of gRPC. An iterable over
        iterables of length 2.
      transmitted_metadata: A metadata value corresponding to original_metadata
        after having been transmitted via gRPC. An iterable over iterables of
        length 2.

    Returns:
      A boolean indicating whether transmitted_metadata accurately reflects
        original_metadata after having been transmitted via gRPC.
    """
    original = collections.defaultdict(list)
    for key, value in original_metadata:
        original[key].append(value)
    transmitted = collections.defaultdict(list)
    for key, value in transmitted_metadata:
        transmitted[key].append(value)

    for key, values in original.items():
        # One shared iterator per key makes this an in-order subsequence test:
        # each search resumes just past the previous match, so a dropped,
        # altered, or same-key-reordered value exhausts the iterator.
        # .get() avoids inserting missing keys into the defaultdict.
        remaining = iter(transmitted.get(key, ()))
        for value in values:
            if not any(value == candidate for candidate in remaining):
                return False
    return True
81
82
def test_secure_channel(target, channel_credentials, server_host_override):
    """Creates a secure Channel to a remote host.

    Args:
      target: The address of the remote server, passed unchanged to
        grpc.secure_channel.
      channel_credentials: The channel credentials object (passed to
        grpc.secure_channel) with which to connect.
      server_host_override: The target name used for SSL host name checking;
        supplied as the "grpc.ssl_target_name_override" channel option.

    Returns:
      The channel returned by grpc.secure_channel, through which RPCs to the
        remote host may be conducted.
    """
    # Overriding the SSL target name lets tests use certificates whose host
    # name differs from the address actually dialed.
    channel = grpc.secure_channel(
        target,
        channel_credentials,
        (("grpc.ssl_target_name_override", server_host_override),),
    )
    return channel
108
109
def test_server(max_workers=10, reuse_port=False):
    """Creates an insecure grpc server backed by a thread pool.

    SO_REUSEPORT is disabled by default (reuse_port=False) to prevent
    cross-talk between servers; pass reuse_port=True to enable it.

    Extra keyword arguments for grpc.server may be supplied through the
    GRPC_ADDITIONAL_SERVER_KWARGS environment variable as a Python dict
    literal (parsed with ast.literal_eval, so only literals are accepted).
    An unset or blank variable means no extra arguments. The dict must not
    contain an "options" key, because options are passed explicitly here.

    Args:
      max_workers: Maximum number of threads in the server's executor.
      reuse_port: Whether to set the "grpc.so_reuseport" option to 1.

    Returns:
      The server created by grpc.server; it is not started here.

    Raises:
      TypeError: If GRPC_ADDITIONAL_SERVER_KWARGS does not evaluate to a dict.
    """
    raw_kwargs = os.environ.get("GRPC_ADDITIONAL_SERVER_KWARGS", "").strip() or "{}"
    server_kwargs = ast.literal_eval(raw_kwargs)
    if not isinstance(server_kwargs, dict):
        raise TypeError(
            "GRPC_ADDITIONAL_SERVER_KWARGS must be a dict literal, "
            f"got {raw_kwargs!r}"
        )
    return grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        options=(("grpc.so_reuseport", int(reuse_port)),),
        **server_kwargs,
    )
122
123
class WaitGroup(object):
    """A counter of outstanding work that threads can block on.

    add() adjusts the count, done() decrements it by one, and wait() blocks
    until the count is zero or less. Every access to the count happens while
    holding the single threading.Condition ``cv``.
    """

    def __init__(self, n=0):
        self.count = n
        self.cv = threading.Condition()

    def add(self, n):
        """Adds n (which may be negative) to the count.

        Wakes all waiters if the count drops to zero or below, so a negative
        delta that drains the count releases blocked wait() calls.
        """
        with self.cv:
            self.count += n
            if self.count <= 0:
                self.cv.notify_all()

    def done(self):
        """Decrements the count by one, waking all waiters once it is <= 0."""
        with self.cv:
            self.count -= 1
            if self.count <= 0:
                self.cv.notify_all()

    def wait(self):
        """Blocks until the count is zero or less."""
        # `with` guarantees the lock is released even if wait is interrupted.
        with self.cv:
            self.cv.wait_for(lambda: self.count <= 0)
146
147
def running_under_gevent():
    """Reports whether the stdlib socket class is gevent's socket class.

    Returns:
      False if gevent.monkey or gevent.socket cannot be imported; otherwise
      whether socket.socket is gevent.socket.socket (presumably true only
      after gevent has monkey-patched sockets).
    """
    try:
        from gevent import monkey  # noqa: F401 -- imported only to probe availability
        import gevent.socket
    except ImportError:
        return False

    import socket

    gevent_socket_class = gevent.socket.socket
    return socket.socket is gevent_socket_class
158