# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Proxies a TCP connection between a single client-server pair.

This proxy is not suitable for production, but should work well for cases in
which a test needs to spy on the bytes put on the wire between a server and
a client.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import select
import socket
import threading

from tests.unit.framework.common import get_socket

# Maximum number of bytes requested from a socket per recv() call.
_TCP_PROXY_BUFFER_SIZE = 1024
# Upper bound on how long one select() call may block, which also bounds how
# long the serving thread takes to notice a stop request.
_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500)


def _init_proxy_socket(gateway_address, gateway_port):
    """Opens a TCP connection to the gateway and returns the socket."""
    proxy_socket = socket.create_connection((gateway_address, gateway_port))
    return proxy_socket


class TcpProxy(object):
    """Proxies a TCP connection between one client and one server.

    Bytes read from accepted client connections are forwarded to the gateway
    ("southbound"); bytes read from the gateway are forwarded to a client
    connection ("northbound"). Running totals of both directions are kept and
    exposed through get_byte_count().

    The proxy may be used as a context manager: entering calls start() and
    exiting calls stop().
    """

    def __init__(self, bind_address, gateway_address, gateway_port):
        """Stores the configuration; no sockets are opened until start().

        Args:
          bind_address: Address passed to get_socket() for the listening
            socket.
          gateway_address: Address of the server the proxy connects to.
          gateway_port: Port of the server the proxy connects to.
        """
        self._bind_address = bind_address
        self._gateway_address = gateway_address
        self._gateway_port = gateway_port

        # Guards the two byte counters, which are written by the serving
        # thread and read/reset through the public accessors.
        self._byte_count_lock = threading.RLock()
        self._sent_byte_count = 0
        self._received_byte_count = 0

        self._stop_event = threading.Event()

        self._port = None
        self._listen_socket = None
        self._proxy_socket = None

        # The following three attributes are owned by the serving thread.
        self._northbound_data = b""
        self._southbound_data = b""
        self._client_sockets = []

        self._thread = threading.Thread(target=self._run_proxy)

    def start(self):
        """Opens the listening and gateway sockets and starts serving."""
        _, self._port, self._listen_socket = get_socket(
            bind_address=self._bind_address)
        self._proxy_socket = _init_proxy_socket(self._gateway_address,
                                                self._gateway_port)
        self._thread.start()

    def get_port(self):
        """Returns the listening port, or None if start() was not called."""
        return self._port

    def _handle_reads(self, sockets_to_read):
        """Services every socket that select() reported as readable.

        Args:
          sockets_to_read: Sockets ready for reading; each must be the
            listening socket, the gateway socket, or a known client socket.

        Raises:
          RuntimeError: If a socket is none of the above.
        """
        for socket_to_read in sockets_to_read:
            if socket_to_read is self._listen_socket:
                # The peer address is not needed by the proxy.
                client_socket, _ = socket_to_read.accept()
                self._client_sockets.append(client_socket)
            elif socket_to_read is self._proxy_socket:
                # NOTE(review): an empty read (gateway closed its end) is not
                # treated specially here; it simply adds nothing.
                data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
                with self._byte_count_lock:
                    self._received_byte_count += len(data)
                self._northbound_data += data
            elif socket_to_read in self._client_sockets:
                data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
                if data:
                    with self._byte_count_lock:
                        self._sent_byte_count += len(data)
                    self._southbound_data += data
                else:
                    # Empty read: the client closed the connection. Close our
                    # end as well; once it leaves the list nothing else would
                    # ever close it.
                    self._client_sockets.remove(socket_to_read)
                    socket_to_read.close()
            else:
                raise RuntimeError('Unidentified socket appeared in read set.')

    def _handle_writes(self, sockets_to_write):
        """Flushes buffered data to every socket reported as writable.

        Args:
          sockets_to_write: Sockets ready for writing. Sockets that are
            neither the gateway socket nor a current client socket (e.g. the
            listening socket, or a client dropped during this iteration) are
            ignored.
        """
        for socket_to_write in sockets_to_write:
            if socket_to_write is self._proxy_socket:
                if self._southbound_data:
                    self._proxy_socket.sendall(self._southbound_data)
                    self._southbound_data = b""
            elif socket_to_write in self._client_sockets:
                if self._northbound_data:
                    socket_to_write.sendall(self._northbound_data)
                    self._northbound_data = b""

    def _run_proxy(self):
        """Serving-thread main loop; runs until the stop event is set."""
        while not self._stop_event.is_set():
            expected_reads = (self._listen_socket, self._proxy_socket) + tuple(
                self._client_sockets)
            expected_writes = expected_reads
            sockets_to_read, sockets_to_write, _ = select.select(
                expected_reads, expected_writes, (),
                _TCP_PROXY_TIMEOUT.total_seconds())
            self._handle_reads(sockets_to_read)
            self._handle_writes(sockets_to_write)
        for client_socket in self._client_sockets:
            client_socket.close()

    def stop(self):
        """Stops the serving thread and closes the proxy's own sockets."""
        self._stop_event.set()
        self._thread.join()
        self._listen_socket.close()
        self._proxy_socket.close()

    def get_byte_count(self):
        """Returns a (sent, received) tuple of byte totals.

        "Sent" counts bytes read from client connections; "received" counts
        bytes read from the gateway connection.
        """
        with self._byte_count_lock:
            return self._sent_byte_count, self._received_byte_count

    def reset_byte_count(self):
        """Resets both the sent and the received byte totals to zero."""
        with self._byte_count_lock:
            # Reset the counter that get_byte_count() actually reports
            # (previously an unrelated `_byte_count` attribute was assigned,
            # leaving the sent total untouched).
            self._sent_byte_count = 0
            self._received_byte_count = 0

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()