• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Types representing the basic pw_rpc concepts: channel, service, method."""
15
16import abc
17from dataclasses import dataclass
18import enum
19from inspect import Parameter
20from typing import (Any, Callable, Collection, Dict, Generic, Iterable,
21                    Iterator, Optional, Tuple, TypeVar, Union)
22
23from google.protobuf import descriptor_pb2, message_factory
24from google.protobuf.descriptor import (FieldDescriptor, MethodDescriptor,
25                                        ServiceDescriptor)
26from google.protobuf.message import Message
27from pw_protobuf_compiler import python_protos
28
29from pw_rpc import ids
30
31
@dataclass(frozen=True)
class Channel:
    """An RPC channel: a numeric ID paired with a packet output callable.

    Attributes:
      id: the channel ID.
      output: called with the bytes of each packet sent on this channel.
    """
    id: int
    output: Callable[[bytes], Any]

    def __repr__(self) -> str:
        # Only the ID is shown; the output callable is omitted.
        return 'Channel({})'.format(self.id)
39
40
class ChannelManipulator(abc.ABC):
    """A pipe interface that may manipulate packets before they're sent.

    ``ChannelManipulator``s allow application-specific packet handling to be
    injected into the packet processing pipeline for an ingress or egress
    channel-like pathway. This is particularly useful for integration testing
    resilience to things like packet loss on a usually-reliable transport. RPC
    server integrations (e.g. ``HdlcRpcLocalServerAndClient``) may provide an
    opportunity to inject a ``ChannelManipulator`` for this use case.

    A ``ChannelManipulator`` implementation should not reassign
    ``send_packet`` itself, as the consumer of a ``ChannelManipulator`` will
    set ``send_packet`` to insert the provided ``ChannelManipulator`` into a
    packet processing path.

    Instances are callable: calling one with packet bytes forwards them to
    ``process_and_send``, so a ``ChannelManipulator`` can be used directly as
    a ``Channel`` output.

    For example:

    .. code-block:: python

      class PacketLogger(ChannelManipulator):
          def process_and_send(self, packet: bytes) -> None:
              _LOG.debug('Received packet with payload: %s', str(packet))
              self.send_packet(packet)


      packet_logger = PacketLogger()

      # Configure actual send command.
      packet_logger.send_packet = socket.sendall

      # Route the output channel through the PacketLogger channel manipulator.
      channels = (Channel(_DEFAULT_CHANNEL, packet_logger),)

      # Create an RPC client.
      client = HdlcRpcClient(socket.read, protos, channels, stdout)
    """
    def __init__(self) -> None:
        # Until a consumer installs a real sink, sent packets are discarded.
        self.send_packet: Callable[[bytes], Any] = lambda _: None

    @abc.abstractmethod
    def process_and_send(self, packet: bytes) -> None:
        """Processes an incoming packet before optionally sending it.

        Implementations of this method may send the processed packet, multiple
        packets, or no packets at all via the registered `send_packet()`
        handler.

        Args:
          packet: the raw packet bytes handed to this manipulator.
        """

    def __call__(self, data: bytes) -> None:
        """Forwards ``data`` to ``process_and_send``."""
        self.process_and_send(data)
90
91
@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service.

    Wraps a protobuf ``ServiceDescriptor`` together with the service's
    calculated ID and its ``Methods``. Instances are frozen and, because
    ``eq=False``, compare by identity.
    """
    _descriptor: ServiceDescriptor
    id: int
    methods: 'Methods'

    @property
    def name(self):
        """The service's short name, taken from its descriptor."""
        descriptor = self._descriptor
        return descriptor.name

    @property
    def full_name(self):
        """The service's fully qualified name, taken from its descriptor."""
        descriptor = self._descriptor
        return descriptor.full_name

    @property
    def package(self):
        """The protobuf package of the file that declares this service."""
        declaring_file = self._descriptor.file
        return declaring_file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> 'Service':
        """Builds a Service, including all of its Methods, from a descriptor.

        The service ID is calculated from the fully qualified service name.
        Each Method refers back to its Service, so the Service is created
        first with a placeholder ``methods`` value that is filled in
        afterwards.
        """
        service_id = ids.calculate(descriptor.full_name)
        service = cls(descriptor, service_id,
                      None)  # type: ignore[arg-type]
        all_methods = Methods(
            Method.from_descriptor(method_descriptor, service)
            for method_descriptor in descriptor.methods)
        # The dataclass is frozen, so bypass its __setattr__.
        object.__setattr__(service, 'methods', all_methods)
        return service

    def __repr__(self) -> str:
        return 'Service({!r})'.format(self.full_name)

    def __str__(self) -> str:
        full_name = self.full_name
        return full_name
128
129
def _streaming_attributes(method) -> Tuple[bool, bool]:
    """Returns ``(server_streaming, client_streaming)`` for a method.

    Args:
      method: a protobuf ``MethodDescriptor``.
    """
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    #     attributes to the generated protobuf code. As a workaround,
    #     deserialize the FileDescriptorProto to get that information.
    owner = method.containing_service

    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.MergeFromString(owner.file.serialized_pb)

    service_proto = file_proto.service[owner.index]  # pylint: disable=no-member
    method_proto = service_proto.method[method.index]
    return (method_proto.server_streaming, method_proto.client_streaming)
141
142
# Maps protobuf field types to the Python types used when describing fields
# in help text. Enum fields map to int.
_PROTO_FIELD_TYPES = dict([
    (FieldDescriptor.TYPE_BOOL, bool),
    (FieldDescriptor.TYPE_BYTES, bytes),
    (FieldDescriptor.TYPE_DOUBLE, float),
    (FieldDescriptor.TYPE_ENUM, int),
    (FieldDescriptor.TYPE_FIXED32, int),
    (FieldDescriptor.TYPE_FIXED64, int),
    (FieldDescriptor.TYPE_FLOAT, float),
    (FieldDescriptor.TYPE_INT32, int),
    (FieldDescriptor.TYPE_INT64, int),
    (FieldDescriptor.TYPE_SFIXED32, int),
    (FieldDescriptor.TYPE_SFIXED64, int),
    (FieldDescriptor.TYPE_SINT32, int),
    (FieldDescriptor.TYPE_SINT64, int),
    (FieldDescriptor.TYPE_STRING, str),
    (FieldDescriptor.TYPE_UINT32, int),
    (FieldDescriptor.TYPE_UINT64, int),
    # These types are not annotated:
    # FieldDescriptor.TYPE_GROUP = 10
    # FieldDescriptor.TYPE_MESSAGE = 11
])
164
165
def _field_type_annotation(field: FieldDescriptor):
    """Creates a field type annotation to use in the help message only.

    Message fields are annotated with their generated message class. Other
    fields use ``_PROTO_FIELD_TYPES``, falling back to ``Parameter.empty``.
    Repeated fields are wrapped in ``Iterable[...]``.
    """
    if field.type != FieldDescriptor.TYPE_MESSAGE:
        base = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)
    else:
        pool = field.message_type.file.pool
        base = message_factory.MessageFactory(pool).GetPrototype(
            field.message_type)

    is_repeated = field.label == FieldDescriptor.LABEL_REPEATED
    return Iterable[base] if is_repeated else base  # type: ignore[valid-type]
178
179
def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Each string has the form ``name=default``. When ``annotations`` is True,
    it has the form ``name: type = default``.

    Args:
      proto_message: a generated protobuf message class or instance; only its
        ``DESCRIPTOR`` is read.
      annotations: whether to include each field's type name.
    """
    for field in proto_message.DESCRIPTOR.fields:
        is_repeated = field.label == FieldDescriptor.LABEL_REPEATED

        if field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
        elif field.type in _PROTO_FIELD_TYPES:
            type_name = _PROTO_FIELD_TYPES[field.type].__name__
        else:
            # Message and group fields are not in _PROTO_FIELD_TYPES; name
            # them by their message type instead of raising KeyError.
            type_name = field.message_type.full_name

        if field.type == FieldDescriptor.TYPE_ENUM and not is_repeated:
            value = field.enum_type.values_by_number[field.default_value].name
            # Qualify the value with the enum's enclosing scope (the enum's
            # full name minus its last component).
            value = f'{type_name.rsplit(".", 1)[0]}.{value}'
        else:
            # Repeated fields (including repeated enums) have a list default,
            # which cannot be looked up by enum number, so show it as-is.
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'
195
196
197def _message_is_type(proto, expected_type) -> bool:
198    """Returns true if the protobuf instance is the expected type."""
199    # Getting protobuf classes from google.protobuf.message_factory may create a
200    # new, unique generated proto class. Any generated classes for a particular
201    # proto message share the same MessageDescriptor instance and are
202    # interchangeable, so check the descriptors in addition to the types.
203    return isinstance(proto, expected_type) or (isinstance(
204        proto, Message) and proto.DESCRIPTOR is expected_type.DESCRIPTOR)
205
206
@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service.

    Normally built from a protobuf ``MethodDescriptor`` with
    ``from_descriptor``. Instances are frozen and, because ``eq=False``,
    compare by identity.
    """

    _descriptor: MethodDescriptor
    service: Service
    id: int
    server_streaming: bool
    client_streaming: bool
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(cls, descriptor: MethodDescriptor,
                        service: Service) -> 'Method':
        """Creates a method from its protobuf descriptor.

        Args:
          descriptor: the protobuf descriptor of the RPC method.
          service: the Service the method belongs to.

        Returns:
          An instance of ``cls``. Its ID is calculated from the method's short
          name, and its request/response types are message classes obtained
          from the descriptor's input/output types.
        """
        input_factory = message_factory.MessageFactory(
            descriptor.input_type.file.pool)
        output_factory = message_factory.MessageFactory(
            descriptor.output_type.file.pool)
        # Construct through cls rather than a hard-coded Method so this
        # alternate constructor also produces instances of subclasses.
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            *_streaming_attributes(descriptor),
            input_factory.GetPrototype(descriptor.input_type),
            output_factory.GetPrototype(descriptor.output_type),
        )

    class Type(enum.Enum):
        """The kind of RPC, determined by which sides stream."""
        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            """Returns the name as lowercase words, e.g. 'server streaming'."""
            return self.name.lower().replace('_', ' ')  # pylint: disable=no-member

    @property
    def name(self) -> str:
        """The method's short name."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The method's fully qualified name."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The protobuf package of the file declaring the method's service."""
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> 'Method.Type':
        """The RPC kind, derived from the two streaming flags."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(self, proto: Optional[Message],
                    proto_kwargs: Optional[Dict[str, Any]]) -> Message:
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Args:
          proto: a request message, or None to build one from proto_kwargs.
          proto_kwargs: field values for a new request; None is treated as an
            empty dict.

        Returns:
          ``proto`` itself when given, otherwise
          ``self.request_type(**proto_kwargs)``.

        Raises:
          TypeError: both ``proto`` and non-empty ``proto_kwargs`` were given,
            or ``proto`` is not a ``request_type`` message.
        """
        if proto_kwargs is None:
            proto_kwargs = {}

        # Test for None explicitly so an object's truthiness never decides
        # whether a request message was provided.
        if proto is not None and proto_kwargs:
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})")

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not _message_is_type(proto, self.request_type):
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(f'Expected a message of type '
                            f'{self.request_type.DESCRIPTOR.full_name}, '
                            f'got {bad_type}')

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        This can be used to make function signatures match the request proto.
        Each parameter is keyword-only and defaults to the field's default
        value.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(field.name,
                            Parameter.KEYWORD_ONLY,
                            annotation=_field_type_annotation(field),
                            default=field.default_value)

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type.

        Types in the service's own package are shown by short name, and all
        others by full name. Streaming types get a 'stream ' prefix.
        """
        stream = 'stream ' if streaming else ''

        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name
328
329
# Generic item type for the accessor classes that follow.
T = TypeVar('T')


def _name(item: Union[Service, Method]) -> str:
    """Returns a Service's full name or a Method's short name."""
    if isinstance(item, Service):
        return item.full_name
    return item.name
335
336
class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure.

    Holds a single ``item`` as an attribute called ``name``.
    """
    def __init__(self, name: str, item: T):
        # The attribute name is only known at runtime.
        self.__setattr__(name, item)
341
342
class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID.

    Values are stored by ID. ``accessor[name_or_id]`` accepts an integer ID
    or a string name, which is converted to an ID with ``ids.calculate``.
    Iterating yields the stored values.
    """
    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: a dict mapping Service/Method keys to the values to store,
            or an iterable of Service/Method objects, each stored as its own
            value.
          as_attrs: '' to add no attributes. 'members' sets each value as an
            attribute named by ``_name`` of its key. 'packages' exposes values
            through package attributes built by ``python_protos.as_packages``.

        Raises:
          ValueError: ``as_attrs`` is not one of the values above.
        """
        # If the members arg was passed as a [values] iterable, convert it to
        # an equivalent dictionary.
        if not isinstance(members, dict):
            members = {m: m for m in members}

        self._by_id = {k.id: v for k, v in members.items()}

        if as_attrs == 'members':
            # Names are only needed in this mode. As with a name-keyed dict,
            # a later member with the same name replaces an earlier one.
            for key, member in members.items():
                setattr(self, _name(key), member)
        elif as_attrs == 'packages':
            for package in python_protos.as_packages(
                (m.package, _AccessByName(m.name, members[m]))
                    for m in members).packages:
                setattr(self, str(package), package)
        elif as_attrs:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getitem__(self, name_or_id: Union[str, int]):
        """Accesses a service/method by the string name or ID.

        Raises:
          KeyError: no item has the given ID (or the ID of the given name).
        """
        try:
            return self._by_id[_id(name_or_id)]
        except KeyError:
            pass

        # Raised outside the except block so the lookup error isn't chained.
        name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
        raise KeyError(f'Unknown ID {_id(name_or_id)}{name}')

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        """Returns True for a known name or ID, or for a stored value.

        Names and IDs are matched as in ``__getitem__``. Stored values are
        also accepted, so every item produced by iteration is reported as
        contained, as the ``Collection`` protocol expects.
        """
        if _id(name_or_id) in self._by_id:
            return True

        return (not isinstance(name_or_id, (str, int))
                and name_or_id in self._by_id.values())

    def __repr__(self) -> str:
        members = ', '.join(repr(m) for m in self._by_id.values())
        return f'{self.__class__.__name__}({members})'
388
389
390def _id(handle: Union[str, int]) -> int:
391    return ids.calculate(handle) if isinstance(handle, str) else handle
392
393
class Methods(ServiceAccessor[Method]):
    """A collection of Method descriptors in a Service.

    Methods are looked up by name or ID. No attribute access is set up.
    """
    def __init__(self, method: Iterable[Method]):
        super().__init__(method, as_attrs='')
398
399
class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors.

    Services are looked up by name or ID. No attribute access is set up.
    """
    def __init__(self, services: Iterable[Service]):
        super().__init__(services, as_attrs='')
404
405
def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: accessor whose items are Services, or objects that
        support ``[method_name]`` lookups.
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the method is not present
    """
    # A '/' separates service from method; otherwise the last '.' does.
    parts = name.split('/') if '/' in name else name.rsplit('.', 1)
    if len(parts) != 2:
        raise ValueError(
            f'Invalid method name {name!r}; expected '
            'package.Service/Method or package.Service.Method')
    service_name, method_name = parts

    service = service_accessor[service_name]
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]
426