• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
"""Wrappers for socket clients to log read and write data."""
15from __future__ import annotations
16
17from typing import Callable, TYPE_CHECKING
18
19import errno
20import re
21import socket
22
23from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker
24
25if TYPE_CHECKING:
26    from _typeshed import ReadableBuffer
27
28
29class SocketClient:
30    """Socket transport implementation."""
31
32    FILE_SOCKET_SERVER = 'file'
33    DEFAULT_SOCKET_SERVER = 'localhost'
34    DEFAULT_SOCKET_PORT = 33000
35    PW_RPC_MAX_PACKET_SIZE = 256
36
37    _InitArgsType = tuple[
38        socket.AddressFamily, int  # pylint: disable=no-member
39    ]
40    # Can be a string, (address, port) for AF_INET or (address, port, flowinfo,
41    # scope_id) AF_INET6.
42    _AddressType = str | tuple[str, int] | tuple[str, int, int, int]
43
44    def __init__(
45        self,
46        config: str,
47        on_disconnect: Callable[[SocketClient], None] | None = None,
48    ):
49        """Creates a socket connection.
50
51        Args:
52          config: The socket configuration. Accepted values and formats are:
53            'default' - uses the default configuration (localhost:33000)
54            'address:port' - An IPv4 address and port.
55            'address' - An IPv4 address. Uses default port 33000.
56            '[address]:port' - An IPv6 address and port.
57            '[address]' - An IPv6 address. Uses default port 33000.
58            'file:path_to_file' - A Unix socket at ``path_to_file``.
59            In the formats above,``address`` can be an actual address or a name
60            that resolves to an address through name-resolution.
61          on_disconnect: An optional callback called when the socket
62            disconnects.
63
64        Raises:
65          TypeError: The type of socket is not supported.
66          ValueError: The socket configuration is invalid.
67        """
68        self.socket: socket.socket
69        (
70            self._socket_init_args,
71            self._address,
72        ) = SocketClient._parse_socket_config(config)
73        self._on_disconnect = on_disconnect
74        self._connected = False
75        self.connect()
76
77    @staticmethod
78    def _parse_socket_config(
79        config: str,
80    ) -> tuple[SocketClient._InitArgsType, SocketClient._AddressType]:
81        """Sets the variables used to create a socket given a config string.
82
83        Raises:
84          TypeError: The type of socket is not supported.
85          ValueError: The socket configuration is invalid.
86        """
87        init_args: SocketClient._InitArgsType
88        address: SocketClient._AddressType
89
90        # Check if this is using the default settings.
91        if config == 'default':
92            init_args = socket.AF_INET6, socket.SOCK_STREAM
93            address = (
94                SocketClient.DEFAULT_SOCKET_SERVER,
95                SocketClient.DEFAULT_SOCKET_PORT,
96            )
97            return init_args, address
98
99        # Check if this is a UNIX socket.
100        unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:'
101        if config.startswith(unix_socket_file_setting):
102            # Unix socket support is available on Windows 10 since April
103            # 2018. However, there is no Python support on Windows yet.
104            # See https://bugs.python.org/issue33408 for more information.
105            if not hasattr(socket, 'AF_UNIX'):
106                raise TypeError(
107                    'Unix sockets are not supported in this environment.'
108                )
109            init_args = (
110                socket.AF_UNIX,  # pylint: disable=no-member
111                socket.SOCK_STREAM,
112            )
113            address = config[len(unix_socket_file_setting) :]
114            return init_args, address
115
116        # Search for IPv4 or IPv6 address or name and port.
117        # First, try to capture an IPv6 address as anything inside []. If there
118        # are no [] capture the IPv4 address. Lastly, capture the port as the
119        # numbers after :, if any.
120        match = re.match(
121            r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/]+):?)'
122            r'(?P<port>[0-9]+)?',
123            config,
124        )
125        invalid_config_message = (
126            f'Invalid socket configuration "{config}"'
127            'Accepted values are "default", "file:<file_path>", '
128            '"<name_or_ipv4_address>" with optional ":<port>", and '
129            '"[<name_or_ipv6_address>]" with optional ":<port>".'
130        )
131        if match is None:
132            raise ValueError(invalid_config_message)
133
134        info = match.groupdict()
135        if info['port']:
136            port = int(info['port'])
137        else:
138            port = SocketClient.DEFAULT_SOCKET_PORT
139
140        if info['ipv4_addr']:
141            ip_addr = info['ipv4_addr']
142        elif info['ipv6_addr']:
143            ip_addr = info['ipv6_addr']
144        else:
145            raise ValueError(invalid_config_message)
146
147        sock_family, sock_type, _, _, address = socket.getaddrinfo(
148            ip_addr, port, type=socket.SOCK_STREAM
149        )[0]
150        init_args = sock_family, sock_type
151        return init_args, address
152
153    def __del__(self):
154        if self._connected:
155            self.socket.close()
156
157    def write(self, data: ReadableBuffer) -> None:
158        """Writes data and detects disconnects."""
159        if not self._connected:
160            raise Exception('Socket is not connected.')
161        try:
162            self.socket.sendall(data)
163        except socket.error as exc:
164            if isinstance(exc.args, tuple) and exc.args[0] == errno.EPIPE:
165                self._handle_disconnect()
166            else:
167                raise exc
168
169    def read(self, num_bytes: int = PW_RPC_MAX_PACKET_SIZE) -> bytes:
170        """Blocks until data is ready and reads up to num_bytes."""
171        if not self._connected:
172            raise Exception('Socket is not connected.')
173        data = self.socket.recv(num_bytes)
174        # Since this is a blocking read, no data returned means the socket is
175        # closed.
176        if not data:
177            self._handle_disconnect()
178        return data
179
180    def connect(self) -> None:
181        """Connects to socket."""
182        self.socket = socket.socket(*self._socket_init_args)
183
184        # Enable reusing address and port for reconnections.
185        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
186        if hasattr(socket, 'SO_REUSEPORT'):
187            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
188        self.socket.connect(self._address)
189        self._connected = True
190
191    def _handle_disconnect(self):
192        """Escalates a socket disconnect to the user."""
193        self.socket.close()
194        self._connected = False
195        if self._on_disconnect:
196            self._on_disconnect(self)
197
198    def fileno(self) -> int:
199        return self.socket.fileno()
200
201
class SocketClientWithLogging(SocketClient):
    """SocketClient that reports its traffic to a bandwidth tracker.

    Every payload returned by read() and every payload handed to write() is
    passed to a SerialBandwidthTracker.
    """

    def __init__(self, *args, **kwargs):
        # The base constructor parses the config and connects; the tracker is
        # only created if that succeeds.
        super().__init__(*args, **kwargs)
        self._bandwidth_tracker = SerialBandwidthTracker()

    def read(
        self, num_bytes: int = SocketClient.PW_RPC_MAX_PACKET_SIZE
    ) -> bytes:
        """Reads up to num_bytes and records the received data."""
        received = super().read(num_bytes)
        tracker = self._bandwidth_tracker
        tracker.track_read_data(received)
        return received

    def write(self, data: ReadableBuffer) -> None:
        """Records data with the tracker, then sends it.

        The data is recorded before the base write runs, so it is counted even
        if that write raises or drops it on disconnect.
        """
        tracker = self._bandwidth_tracker
        tracker.track_write_data(data)
        super().write(data)
219