1# Copyright 2019 the gRPC authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14""" Proxies a TCP connection between a single client-server pair. 15 16This proxy is not suitable for production, but should work well for cases in 17which a test needs to spy on the bytes put on the wire between a server and 18a client. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import datetime 26import select 27import socket 28import threading 29 30from tests.unit.framework.common import get_socket 31 32_TCP_PROXY_BUFFER_SIZE = 1024 33_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500) 34 35 36def _init_proxy_socket(gateway_address, gateway_port): 37 proxy_socket = socket.create_connection((gateway_address, gateway_port)) 38 return proxy_socket 39 40 41class TcpProxy(object): 42 """Proxies a TCP connection between one client and one server.""" 43 44 def __init__(self, bind_address, gateway_address, gateway_port): 45 self._bind_address = bind_address 46 self._gateway_address = gateway_address 47 self._gateway_port = gateway_port 48 49 self._byte_count_lock = threading.RLock() 50 self._sent_byte_count = 0 51 self._received_byte_count = 0 52 53 self._stop_event = threading.Event() 54 55 self._port = None 56 self._listen_socket = None 57 self._proxy_socket = None 58 59 # The following three attributes are owned by the serving thread. 60 self._northbound_data = b"" 61 self._southbound_data = b"" 62 self._client_sockets = [] 63 64 self._thread = threading.Thread(target=self._run_proxy) 65 66 def start(self): 67 _, self._port, self._listen_socket = get_socket( 68 bind_address=self._bind_address) 69 self._proxy_socket = _init_proxy_socket(self._gateway_address, 70 self._gateway_port) 71 self._thread.start() 72 73 def get_port(self): 74 return self._port 75 76 def _handle_reads(self, sockets_to_read): 77 for socket_to_read in sockets_to_read: 78 if socket_to_read is self._listen_socket: 79 client_socket, client_address = socket_to_read.accept() 80 self._client_sockets.append(client_socket) 81 elif socket_to_read is self._proxy_socket: 82 data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) 83 with self._byte_count_lock: 84 self._received_byte_count += len(data) 85 self._northbound_data += data 86 elif socket_to_read in self._client_sockets: 87 data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) 88 if data: 89 with self._byte_count_lock: 90 self._sent_byte_count += len(data) 91 self._southbound_data += data 92 else: 93 self._client_sockets.remove(socket_to_read) 94 else: 95 raise RuntimeError('Unidentified socket appeared in read set.') 96 97 def _handle_writes(self, sockets_to_write): 98 for socket_to_write in sockets_to_write: 99 if socket_to_write is self._proxy_socket: 100 if self._southbound_data: 101 self._proxy_socket.sendall(self._southbound_data) 102 self._southbound_data = b"" 103 elif socket_to_write in self._client_sockets: 104 if self._northbound_data: 105 socket_to_write.sendall(self._northbound_data) 106 self._northbound_data = b"" 107 108 def _run_proxy(self): 109 while not self._stop_event.is_set(): 110 expected_reads = (self._listen_socket, self._proxy_socket) + tuple( 111 self._client_sockets) 112 expected_writes = expected_reads 113 sockets_to_read, sockets_to_write, _ = select.select( 114 expected_reads, expected_writes, (), 115 _TCP_PROXY_TIMEOUT.total_seconds()) 116 self._handle_reads(sockets_to_read) 117 self._handle_writes(sockets_to_write) 118 for client_socket in self._client_sockets: 119 client_socket.close() 120 121 def stop(self): 122 self._stop_event.set() 123 self._thread.join() 124 self._listen_socket.close() 125 self._proxy_socket.close() 126 127 def get_byte_count(self): 128 with self._byte_count_lock: 129 return self._sent_byte_count, self._received_byte_count 130 131 def reset_byte_count(self): 132 with self._byte_count_lock: 133 self._byte_count = 0 134 self._received_byte_count = 0 135 136 def __enter__(self): 137 self.start() 138 return self 139 140 def __exit__(self, exc_type, exc_val, exc_tb): 141 self.stop() 142