• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 the gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14""" Proxies a TCP connection between a single client-server pair.
15
16This proxy is not suitable for production, but should work well for cases in
17which a test needs to spy on the bytes put on the wire between a server and
18a client.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import datetime
26import select
27import socket
28import threading
29
30from tests.unit.framework.common import get_socket
31
# Maximum number of bytes read from a socket by a single recv() call.
_TCP_PROXY_BUFFER_SIZE = 1024
# Upper bound on how long each select() call in the serving loop blocks
# before the loop re-checks whether the proxy has been asked to stop.
_TCP_PROXY_TIMEOUT = datetime.timedelta(seconds=0.5)
34
35
def _init_proxy_socket(gateway_address, gateway_port):
    """Opens the upstream TCP connection that the proxy forwards traffic to.

    Args:
      gateway_address: Host name or IP address of the upstream server.
      gateway_port: Port of the upstream server.

    Returns:
      The connected socket produced by ``socket.create_connection``.
    """
    gateway = (gateway_address, gateway_port)
    return socket.create_connection(gateway)
39
40
class TcpProxy(object):
    """Proxies a TCP connection between one client and one server.

    ``start`` obtains a listening socket on ``bind_address`` (its port is
    exposed through ``get_port``) and opens one upstream connection to
    ``gateway_address:gateway_port``. A background thread then:

    * forwards bytes read from accepted clients to the upstream connection
      ("southbound"), counting them as sent;
    * forwards bytes read from the upstream connection to the first client
      socket found writable ("northbound"), counting them as received.

    The counters can be read with ``get_byte_count`` and cleared with
    ``reset_byte_count``. Instances work as context managers: ``__enter__``
    calls ``start`` and ``__exit__`` calls ``stop``.
    """

    def __init__(self, bind_address, gateway_address, gateway_port):
        self._bind_address = bind_address
        self._gateway_address = gateway_address
        self._gateway_port = gateway_port

        # Guards the byte counters: the serving thread increments them while
        # callers may read or reset them from other threads.
        self._byte_count_lock = threading.RLock()
        self._sent_byte_count = 0
        self._received_byte_count = 0

        self._stop_event = threading.Event()

        self._port = None
        self._listen_socket = None
        self._proxy_socket = None

        # The following three attributes are owned by the serving thread.
        # Bytes read from the gateway, waiting to be written to a client.
        self._northbound_data = b""
        # Bytes read from clients, waiting to be written to the gateway.
        self._southbound_data = b""
        self._client_sockets = []

        self._thread = threading.Thread(target=self._run_proxy)

    def start(self):
        """Acquires the listening socket, connects upstream, starts serving.

        If the upstream connection cannot be established, the listening
        socket is closed before the error propagates so it is not leaked
        (``__exit__`` never runs when ``__enter__`` raises).
        """
        _, self._port, self._listen_socket = get_socket(
            bind_address=self._bind_address)
        try:
            self._proxy_socket = _init_proxy_socket(self._gateway_address,
                                                    self._gateway_port)
        except Exception:
            self._listen_socket.close()
            self._listen_socket = None
            raise
        self._thread.start()

    def get_port(self):
        """Returns the listening port, or None before ``start`` is called."""
        return self._port

    def _handle_reads(self, sockets_to_read):
        """Services every socket that ``select`` reported as readable.

        Raises:
          RuntimeError: if a socket not owned by this proxy is reported.
        """
        for socket_to_read in sockets_to_read:
            if socket_to_read is self._listen_socket:
                client_socket, _ = socket_to_read.accept()
                self._client_sockets.append(client_socket)
            elif socket_to_read is self._proxy_socket:
                # NOTE(review): an empty read here means the gateway closed
                # the connection; the socket then stays readable and the
                # loop keeps polling it until stop() is called.
                data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
                with self._byte_count_lock:
                    self._received_byte_count += len(data)
                self._northbound_data += data
            elif socket_to_read in self._client_sockets:
                data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
                if data:
                    with self._byte_count_lock:
                        self._sent_byte_count += len(data)
                    self._southbound_data += data
                else:
                    # An empty read means the client closed its end: stop
                    # tracking it and release the descriptor.
                    self._client_sockets.remove(socket_to_read)
                    socket_to_read.close()
            else:
                raise RuntimeError('Unidentified socket appeared in read set.')

    def _handle_writes(self, sockets_to_write):
        """Flushes buffered data to the sockets ``select`` found writable."""
        for socket_to_write in sockets_to_write:
            if socket_to_write is self._proxy_socket:
                if self._southbound_data:
                    self._proxy_socket.sendall(self._southbound_data)
                    self._southbound_data = b""
            elif socket_to_write in self._client_sockets:
                if self._northbound_data:
                    socket_to_write.sendall(self._northbound_data)
                    self._northbound_data = b""

    def _run_proxy(self):
        """Serving-thread loop: multiplexes all sockets until ``stop``."""
        try:
            while not self._stop_event.is_set():
                expected_reads = [self._listen_socket, self._proxy_socket]
                expected_reads.extend(self._client_sockets)
                # Only ask for writability when there is something to write.
                # Connected sockets are almost always writable, so polling
                # them unconditionally makes select() return immediately and
                # turns this loop into a busy spin.
                expected_writes = []
                if self._southbound_data:
                    expected_writes.append(self._proxy_socket)
                if self._northbound_data:
                    expected_writes.extend(self._client_sockets)
                sockets_to_read, sockets_to_write, _ = select.select(
                    expected_reads, expected_writes, (),
                    _TCP_PROXY_TIMEOUT.total_seconds())
                self._handle_reads(sockets_to_read)
                self._handle_writes(sockets_to_write)
        finally:
            # Close client connections even if the loop died with an error.
            for client_socket in self._client_sockets:
                client_socket.close()
            self._client_sockets = []

    def stop(self):
        """Stops the serving thread and closes the proxy's own sockets.

        Safe to call on a proxy whose ``start`` never ran or failed.
        """
        self._stop_event.set()
        # Thread.ident is None until the thread has been started; joining an
        # unstarted thread raises RuntimeError.
        if self._thread.ident is not None:
            self._thread.join()
        if self._listen_socket is not None:
            self._listen_socket.close()
        if self._proxy_socket is not None:
            self._proxy_socket.close()

    def get_byte_count(self):
        """Returns ``(sent_byte_count, received_byte_count)``.

        "Sent" counts bytes read from clients; "received" counts bytes read
        from the gateway.
        """
        with self._byte_count_lock:
            return self._sent_byte_count, self._received_byte_count

    def reset_byte_count(self):
        """Resets both byte counters to zero."""
        with self._byte_count_lock:
            self._sent_byte_count = 0
            self._received_byte_count = 0

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
142