# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common code used throughout tests of gRPC."""

import collections
import threading

from concurrent import futures
import grpc
import six

INVOCATION_INITIAL_METADATA = (
    ('0', 'abc'),
    ('1', 'def'),
    ('2', 'ghi'),
)
SERVICE_INITIAL_METADATA = (
    ('3', 'jkl'),
    ('4', 'mno'),
    ('5', 'pqr'),
)
SERVICE_TERMINAL_METADATA = (
    ('6', 'stu'),
    ('7', 'vwx'),
    ('8', 'yza'),
)
DETAILS = 'test details'


def _values_by_key(metadata):
    """Groups metadata values by key, preserving their relative order.

    Args:
      metadata: An iterable over iterables of length 2 (key, value).

    Returns:
      A dict mapping each key to the list of its values in encounter order.
    """
    grouped = collections.defaultdict(list)
    for key, value in metadata:
        grouped[key].append(value)
    return grouped


def _appears_in_order(expected_values, actual_values):
    """Returns whether expected_values is an ordered subsequence of actual_values.

    Extra values may be interleaved in actual_values, but every expected value
    must be found, and in the same relative order.
    """
    remaining = iter(actual_values)
    # A membership test against an iterator consumes it up to and including
    # the first match, so each later expected value can only be matched at a
    # later position than the previous one.
    return all(value in remaining for value in expected_values)


def metadata_transmitted(original_metadata, transmitted_metadata):
    """Judges whether or not metadata was acceptably transmitted.

    gRPC is allowed to insert key-value pairs into the metadata values given by
    applications and to reorder key-value pairs with different keys but it is not
    allowed to alter existing key-value pairs or to reorder key-value pairs with
    the same key.

    Args:
      original_metadata: A metadata value used in a test of gRPC. An iterable over
        iterables of length 2.
      transmitted_metadata: A metadata value corresponding to original_metadata
        after having been transmitted via gRPC. An iterable over iterables of
        length 2.

    Returns:
      A boolean indicating whether transmitted_metadata accurately reflects
        original_metadata after having been transmitted via gRPC. Every key of
        original_metadata must be checked; an empty original_metadata is
        trivially reflected.
    """
    original = _values_by_key(original_metadata)
    transmitted = _values_by_key(transmitted_metadata)
    # .get() rather than indexing so the lookup does not insert empty entries
    # into the defaultdict for keys that were never transmitted.
    return all(
        _appears_in_order(values, transmitted.get(key, ()))
        for key, values in six.iteritems(original))


def test_secure_channel(target, channel_credentials, server_host_override):
    """Creates a secure Channel to a remote host.

    Args:
      target: The address of the remote host to which to connect; passed
        unchanged to grpc.secure_channel.
      channel_credentials: The grpc.ChannelCredentials with which to connect.
      server_host_override: The target name used for SSL host name checking
        (set as the 'grpc.ssl_target_name_override' channel option).

    Returns:
      A grpc.Channel to the remote host through which RPCs may be conducted.
    """
    channel = grpc.secure_channel(target, channel_credentials, ((
        'grpc.ssl_target_name_override',
        server_host_override,
    ),))
    return channel


def test_server(max_workers=10, reuse_port=False):
    """Creates a grpc server backed by a thread pool, for use in tests.

    No ports are added here; the caller decides how the server listens.

    Args:
      max_workers: The maximum number of threads in the server's
        ThreadPoolExecutor.
      reuse_port: Whether to enable SO_REUSEPORT ('grpc.so_reuseport').
        Disabled by default to prevent cross-talk between test servers.

    Returns:
      A grpc.Server that has not been started.
    """
    return grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers),
                       options=(('grpc.so_reuseport', int(reuse_port)),))


class WaitGroup(object):
    """A counter that lets threads block until it drops to zero or below.

    add() adjusts the count by n, done() decrements it by one, and wait()
    blocks for as long as the count is positive.
    """

    def __init__(self, n=0):
        # Outstanding work count; guarded by self.cv.
        self.count = n
        self.cv = threading.Condition()

    def add(self, n):
        """Adjusts the count by n (which may be negative)."""
        with self.cv:
            self.count += n
            self._notify_if_released()

    def done(self):
        """Decrements the count by one."""
        with self.cv:
            self.count -= 1
            self._notify_if_released()

    def wait(self):
        """Blocks until the count is zero or below."""
        with self.cv:
            # Loop to tolerate spurious wakeups and notifications issued
            # while the count was still positive for other waiters.
            while self.count > 0:
                self.cv.wait()

    def _notify_if_released(self):
        # Must be called with self.cv held. Uses <= 0 rather than == 0 so that
        # a negative add() or an overshoot below zero still wakes waiters,
        # whose loop condition is count > 0.
        if self.count <= 0:
            self.cv.notify_all()