# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Wrappers for socket clients to log read and write data."""
from __future__ import annotations

from typing import Callable, TYPE_CHECKING

import errno
import re
import socket

from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker

if TYPE_CHECKING:
    from _typeshed import ReadableBuffer


class SocketClient:
    """Socket transport implementation."""

    FILE_SOCKET_SERVER = 'file'
    DEFAULT_SOCKET_SERVER = 'localhost'
    DEFAULT_SOCKET_PORT = 33000
    PW_RPC_MAX_PACKET_SIZE = 256

    # Highest valid TCP port number.
    _MAX_PORT = 65535

    _InitArgsType = tuple[
        socket.AddressFamily, int  # pylint: disable=no-member
    ]
    # Can be a string, (address, port) for AF_INET or (address, port, flowinfo,
    # scope_id) AF_INET6.
    _AddressType = str | tuple[str, int] | tuple[str, int, int, int]

    # Matches "[<ipv6_address_or_name>]" or "<ipv4_address_or_name>", each
    # optionally followed by ":<port>". It is applied with fullmatch() so that
    # trailing garbage (e.g. "localhost:abc") is rejected instead of being
    # silently ignored. '-' is allowed so hyphenated host names parse whole.
    _ADDRESS_PATTERN = re.compile(
        r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/\-]+):?)'
        r'(?P<port>[0-9]+)?'
    )

    def __init__(
        self,
        config: str,
        on_disconnect: Callable[[SocketClient], None] | None = None,
    ):
        """Creates a socket connection.

        Args:
            config: The socket configuration. Accepted values and formats are:
                'default' - uses the default configuration (localhost:33000)
                'address:port' - An IPv4 address and port.
                'address' - An IPv4 address. Uses default port 33000.
                '[address]:port' - An IPv6 address and port.
                '[address]' - An IPv6 address. Uses default port 33000.
                'file:path_to_file' - A Unix socket at ``path_to_file``.
                In the formats above, ``address`` can be an actual address or a
                name that resolves to an address through name-resolution.
            on_disconnect: An optional callback called when the socket
                disconnects.

        Raises:
            TypeError: The type of socket is not supported.
            ValueError: The socket configuration is invalid.
            OSError: The address could not be resolved or the connection
                failed.
        """
        self.socket: socket.socket
        (
            self._socket_init_args,
            self._address,
        ) = SocketClient._parse_socket_config(config)
        self._on_disconnect = on_disconnect
        self._connected = False
        self.connect()

    @staticmethod
    def _parse_socket_config(
        config: str,
    ) -> tuple[SocketClient._InitArgsType, SocketClient._AddressType]:
        """Returns the arguments to create and connect a socket for a config.

        Args:
            config: A socket configuration string in one of the formats
                accepted by ``SocketClient.__init__``.

        Returns:
            A ``(init_args, address)`` pair: ``init_args`` is the
            ``(family, type)`` passed to ``socket.socket`` and ``address`` is
            the value passed to ``socket.connect``.

        Raises:
            TypeError: The type of socket is not supported.
            ValueError: The socket configuration is invalid.
            socket.gaierror: The address or name could not be resolved.
        """
        init_args: SocketClient._InitArgsType
        address: SocketClient._AddressType

        # Check if this is using the default settings. Note that the default
        # uses an IPv6 socket, so the 'localhost' name is connected to over
        # IPv6.
        if config == 'default':
            init_args = socket.AF_INET6, socket.SOCK_STREAM
            address = (
                SocketClient.DEFAULT_SOCKET_SERVER,
                SocketClient.DEFAULT_SOCKET_PORT,
            )
            return init_args, address

        # Check if this is a UNIX socket.
        unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:'
        if config.startswith(unix_socket_file_setting):
            # Unix socket support is available on Windows 10 since April
            # 2018. However, there is no Python support on Windows yet.
            # See https://bugs.python.org/issue33408 for more information.
            if not hasattr(socket, 'AF_UNIX'):
                raise TypeError(
                    'Unix sockets are not supported in this environment.'
                )
            init_args = (
                socket.AF_UNIX,  # pylint: disable=no-member
                socket.SOCK_STREAM,
            )
            address = config[len(unix_socket_file_setting) :]
            return init_args, address

        # Search for IPv4 or IPv6 address or name and port. First, try to
        # capture an IPv6 address as anything inside []. If there are no []
        # capture the IPv4 address. Lastly, capture the port as the numbers
        # after :, if any. The whole config must match.
        match = SocketClient._ADDRESS_PATTERN.fullmatch(config)
        invalid_config_message = (
            f'Invalid socket configuration "{config}". '
            'Accepted values are "default", "file:<file_path>", '
            '"<name_or_ipv4_address>" with optional ":<port>", and '
            '"[<name_or_ipv6_address>]" with optional ":<port>".'
        )
        if match is None:
            raise ValueError(invalid_config_message)

        info = match.groupdict()
        if info['port']:
            port = int(info['port'])
            if port > SocketClient._MAX_PORT:
                raise ValueError(
                    f'Invalid socket port {port} in "{config}"; ports must '
                    f'be at most {SocketClient._MAX_PORT}.'
                )
        else:
            port = SocketClient.DEFAULT_SOCKET_PORT

        if info['ipv4_addr']:
            ip_addr = info['ipv4_addr']
        elif info['ipv6_addr']:
            ip_addr = info['ipv6_addr']
        else:
            raise ValueError(invalid_config_message)

        # Use the first resolved entry; its family determines the socket type.
        sock_family, sock_type, _, _, address = socket.getaddrinfo(
            ip_addr, port, type=socket.SOCK_STREAM
        )[0]
        init_args = sock_family, sock_type
        return init_args, address

    def __del__(self):
        # _connected is absent if __init__ raised before assigning it (e.g. an
        # invalid config), so default to "not connected" in that case.
        if getattr(self, '_connected', False):
            self.socket.close()

    def write(self, data: ReadableBuffer) -> None:
        """Writes all of ``data`` and detects disconnects.

        A broken pipe (``EPIPE``) is handled as a disconnect: the socket is
        closed and the ``on_disconnect`` callback, if any, is invoked instead
        of raising.

        Raises:
            Exception: The socket is not connected.
            OSError: Any other socket error raised while sending.
        """
        if not self._connected:
            raise Exception('Socket is not connected.')
        try:
            self.socket.sendall(data)
        except OSError as exc:
            if exc.errno == errno.EPIPE:
                self._handle_disconnect()
            else:
                raise

    def read(self, num_bytes: int = PW_RPC_MAX_PACKET_SIZE) -> bytes:
        """Blocks until data is ready and reads up to num_bytes.

        Returns:
            The bytes read. An empty result means the peer closed the
            connection; the disconnect is handled before returning.

        Raises:
            Exception: The socket is not connected.
        """
        if not self._connected:
            raise Exception('Socket is not connected.')
        data = self.socket.recv(num_bytes)
        # Since this is a blocking read, no data returned means the socket is
        # closed.
        if not data:
            self._handle_disconnect()
        return data

    def connect(self) -> None:
        """Connects to socket.

        Creates a new socket from the parsed configuration and connects it.
        A previously connected socket is closed first, and the new socket is
        closed if connecting fails, so no file descriptor is leaked.

        Raises:
            OSError: The socket could not be created or connected.
        """
        if self._connected:
            self.socket.close()
            self._connected = False
        self.socket = socket.socket(*self._socket_init_args)
        try:
            # Enable reusing address and port for reconnections.
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if hasattr(socket, 'SO_REUSEPORT'):
                self.socket.setsockopt(
                    socket.SOL_SOCKET, socket.SO_REUSEPORT, 1
                )
            self.socket.connect(self._address)
        except BaseException:
            # Release the descriptor of the socket that failed to connect.
            self.socket.close()
            raise
        self._connected = True

    def _handle_disconnect(self):
        """Escalates a socket disconnect to the user."""
        self.socket.close()
        self._connected = False
        if self._on_disconnect:
            self._on_disconnect(self)

    def fileno(self) -> int:
        """Returns the file descriptor of the underlying socket."""
        return self.socket.fileno()


class SocketClientWithLogging(SocketClient):
    """Socket with read and write wrappers for logging."""

    def __init__(self, *args, **kwargs):
        """Connects like ``SocketClient`` and sets up bandwidth tracking."""
        super().__init__(*args, **kwargs)
        self._bandwidth_tracker = SerialBandwidthTracker()

    def read(
        self, num_bytes: int = SocketClient.PW_RPC_MAX_PACKET_SIZE
    ) -> bytes:
        """Reads up to ``num_bytes`` and records them with the tracker."""
        data = super().read(num_bytes)
        self._bandwidth_tracker.track_read_data(data)
        return data

    def write(self, data: ReadableBuffer) -> None:
        """Records ``data`` with the tracker, then writes it."""
        self._bandwidth_tracker.track_write_data(data)
        super().write(data)