• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15This contains helpers for gRPC services defined in
16https://github.com/grpc/grpc-proto/blob/master/grpc/channelz/v1/channelz.proto
17"""
18import ipaddress
19import logging
20from typing import Iterator, Optional
21
22import grpc
23from grpc_channelz.v1 import channelz_pb2
24from grpc_channelz.v1 import channelz_pb2_grpc
25
26import framework.rpc
27
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type aliases
# Short names for the generated channelz_pb2 protobuf messages used below.
# Names with a leading underscore are the request/response wrappers; by
# convention they are private to this module. Names without it are the
# entities referenced in the public method signatures of the helper class.
# Channel
Channel = channelz_pb2.Channel
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
# Nested enum of ChannelConnectivityState; pylint cannot see members that
# protobuf generates dynamically, hence the suppression.
ChannelState = ChannelConnectivityState.State  # pylint: disable=no-member
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
# Subchannel
Subchannel = channelz_pb2.Subchannel
_GetSubchannelRequest = channelz_pb2.GetSubchannelRequest
_GetSubchannelResponse = channelz_pb2.GetSubchannelResponse
# Server
Server = channelz_pb2.Server
_GetServersRequest = channelz_pb2.GetServersRequest
_GetServersResponse = channelz_pb2.GetServersResponse
# Sockets
Socket = channelz_pb2.Socket
SocketRef = channelz_pb2.SocketRef
_GetSocketRequest = channelz_pb2.GetSocketRequest
_GetSocketResponse = channelz_pb2.GetSocketResponse
Address = channelz_pb2.Address
Security = channelz_pb2.Security
# Server Sockets
_GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
_GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
55
56
class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
    """Helper for querying a Channelz service over a gRPC channel.

    Every RPC is issued through ``call_unary_with_deadline``, which is not
    defined in this class (presumably inherited from ``GrpcClientHelper``).
    The ``**kwargs`` accepted by the methods below are forwarded to it
    unchanged.
    """
    stub: channelz_pb2_grpc.ChannelzStub

    def __init__(self, channel: grpc.Channel):
        """Bind the helper to ``channel`` using the Channelz stub class."""
        super().__init__(channel, channelz_pb2_grpc.ChannelzStub)

    @staticmethod
    def is_sock_tcpip_address(address: Address) -> bool:
        """Return True if the ``address`` oneof is set to ``tcpip_address``."""
        return address.WhichOneof('address') == 'tcpip_address'

    @staticmethod
    def is_ipv4(tcpip_address: Address.TcpIpAddress) -> bool:
        """Return True if ``tcpip_address`` holds a 4-byte (IPv4) address."""
        # According to proto, tcpip_address.ip_address is either IPv4 or IPv6.
        # Correspondingly, it's either 4 bytes or 16 bytes in length.
        return len(tcpip_address.ip_address) == 4

    @classmethod
    def sock_address_to_str(cls, address: Address) -> str:
        """Format a TCP/IP socket address as ``'<ip>:<port>'``.

        Raises:
            NotImplementedError: if ``address`` is not a ``tcpip_address``.
        """
        if not cls.is_sock_tcpip_address(address):
            raise NotImplementedError('Only tcpip_address implemented')

        tcpip_address: Address.TcpIpAddress = address.tcpip_address
        # ip_address is packed bytes; any length other than 4 is parsed
        # as IPv6.
        if cls.is_ipv4(tcpip_address):
            ip = ipaddress.IPv4Address(tcpip_address.ip_address)
        else:
            ip = ipaddress.IPv6Address(tcpip_address.ip_address)
        return f'{ip}:{tcpip_address.port}'

    @classmethod
    def sock_addresses_pretty(cls, socket: Socket) -> str:
        """Return a ``'local=..., remote=...'`` summary of ``socket``.

        Raises:
            NotImplementedError: if either endpoint is not a
                ``tcpip_address`` (see ``sock_address_to_str``).
        """
        return (f'local={cls.sock_address_to_str(socket.local)}, '
                f'remote={cls.sock_address_to_str(socket.remote)}')

    @staticmethod
    def find_server_socket_matching_client(
            server_sockets: Iterator[Socket],
            client_socket: Socket) -> Optional[Socket]:
        """Find the server-side socket paired with ``client_socket``.

        A match is the first server socket whose ``remote`` address equals
        the client socket's ``local`` address.

        Returns:
            The matching server socket, or None if there is no match.
        """
        for server_socket in server_sockets:
            if server_socket.remote == client_socket.local:
                return server_socket
        return None

    def find_channels_for_target(self, target: str,
                                 **kwargs) -> Iterator[Channel]:
        """Lazily yield the root channels whose ``data.target`` is ``target``."""
        return (channel for channel in self.list_channels(**kwargs)
                if channel.data.target == target)

    def find_server_listening_on_port(self, port: int,
                                      **kwargs) -> Optional[Server]:
        """Return the first server with a TCP/IP listen socket on ``port``.

        Each listen socket reference is resolved with a ``GetSocket`` call.

        Returns:
            The matching server, or None if no server listens on ``port``.
        """
        for server in self.list_servers(**kwargs):
            listen_socket_ref: SocketRef
            for listen_socket_ref in server.listen_socket:
                listen_socket = self.get_socket(listen_socket_ref.socket_id,
                                                **kwargs)
                listen_address: Address = listen_socket.local
                if (self.is_sock_tcpip_address(listen_address) and
                        listen_address.tcpip_address.port == port):
                    return server
        return None

    def list_channels(self, **kwargs) -> Iterator[Channel]:
        """
        Iterate over all pages of all root channels.

        Root channels are those which application has directly created.
        This does not include subchannels nor non-top level channels.
        """
        start: int = -1
        response: Optional[_GetTopChannelsResponse] = None
        # `start < 0` holds only before the first request, which guards the
        # `response.end` access while `response` is still None.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetTopChannels',
                req=_GetTopChannelsRequest(start_channel_id=start),
                **kwargs)
            for channel in response.channel:
                start = max(start, channel.ref.channel_id)
                yield channel

    def list_servers(self, **kwargs) -> Iterator[Server]:
        """Iterate over all pages of all servers that exist in the process."""
        start: int = -1
        response: Optional[_GetServersResponse] = None
        # `start < 0` holds only before the first request, which guards the
        # `response.end` access while `response` is still None.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetServers',
                req=_GetServersRequest(start_server_id=start),
                **kwargs)
            for server in response.server:
                start = max(start, server.ref.server_id)
                yield server

    def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
        """List all server sockets that exist in server process.

        Iterating over the results will resolve additional pages automatically.
        Each socket reference is resolved to a full Socket via ``get_socket``.
        """
        start: int = -1
        response: Optional[_GetServerSocketsResponse] = None
        # `start < 0` holds only before the first request, which guards the
        # `response.end` access while `response` is still None.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetServerSockets',
                req=_GetServerSocketsRequest(server_id=server.ref.server_id,
                                             start_socket_id=start),
                **kwargs)
            socket_ref: SocketRef
            for socket_ref in response.socket_ref:
                start = max(start, socket_ref.socket_id)
                # Yield actual socket
                yield self.get_socket(socket_ref.socket_id, **kwargs)

    def list_channel_sockets(self, channel: Channel,
                             **kwargs) -> Iterator[Socket]:
        """List all sockets of all subchannels of a given channel."""
        for subchannel in self.list_channel_subchannels(channel, **kwargs):
            yield from self.list_subchannels_sockets(subchannel, **kwargs)

    def list_channel_subchannels(self, channel: Channel,
                                 **kwargs) -> Iterator[Subchannel]:
        """List all subchannels of a given channel."""
        for subchannel_ref in channel.subchannel_ref:
            yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)

    def list_subchannels_sockets(self, subchannel: Subchannel,
                                 **kwargs) -> Iterator[Socket]:
        """List all sockets of a given subchannel."""
        for socket_ref in subchannel.socket_ref:
            yield self.get_socket(socket_ref.socket_id, **kwargs)

    def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel:
        """Return a single Subchannel, otherwise raises RpcError."""
        response: _GetSubchannelResponse = self.call_unary_with_deadline(
            rpc='GetSubchannel',
            req=_GetSubchannelRequest(subchannel_id=subchannel_id),
            **kwargs)
        return response.subchannel

    def get_socket(self, socket_id, **kwargs) -> Socket:
        """Return a single Socket, otherwise raises RpcError."""
        response: _GetSocketResponse = self.call_unary_with_deadline(
            rpc='GetSocket',
            req=_GetSocketRequest(socket_id=socket_id),
            **kwargs)
        return response.socket
208