• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Types representing the basic pw_rpc concepts: channel, service, method."""
15
16import abc
17from dataclasses import dataclass
18import enum
19from inspect import Parameter
20from typing import (
21    Any,
22    Callable,
23    Collection,
24    Dict,
25    Generic,
26    Iterable,
27    Iterator,
28    Optional,
29    Tuple,
30    TypeVar,
31    Union,
32)
33
34from google.protobuf import descriptor_pb2, message_factory
35from google.protobuf.descriptor import (
36    FieldDescriptor,
37    MethodDescriptor,
38    ServiceDescriptor,
39)
40from google.protobuf.message import Message
41from pw_protobuf_compiler import python_protos
42
43from pw_rpc import ids
44
45
@dataclass(frozen=True)
class Channel:
    """An RPC channel: a numeric ID plus a callable that receives its bytes."""

    id: int
    # Invoked with the outgoing packet bytes for this channel.
    output: Callable[[bytes], Any]

    def __repr__(self) -> str:
        # Only the channel ID is shown; the output callable is left out.
        return 'Channel({})'.format(self.id)
53
54
class ChannelManipulator(abc.ABC):
    """A pipe interface that may manipulate packets before they're sent.

    ``ChannelManipulator``s allow application-specific packet handling to be
    injected into the packet processing pipeline for an ingress or egress
    channel-like pathway. This is particularly useful for integration testing
    resilience to things like packet loss on a usually-reliable transport. RPC
    server integrations (e.g. ``HdlcRpcLocalServerAndClient``) may provide an
    opportunity to inject a ``ChannelManipulator`` for this use case.

    A ``ChannelManipulator`` should not modify ``send_packet``, as the consumer
    of a ``ChannelManipulator`` will use ``send_packet`` to insert the provided
    ``ChannelManipulator`` into a packet processing path.

    For example:

    .. code-block:: python

      class PacketLogger(ChannelManipulator):
          def process_and_send(self, packet: bytes) -> None:
              _LOG.debug('Received packet with payload: %s', str(packet))
              self.send_packet(packet)


      packet_logger = PacketLogger()

      # Configure actual send command.
      packet_logger.send_packet = socket.sendall

      # Route the output channel through the PacketLogger channel manipulator.
      channels = (Channel(_DEFAULT_CHANNEL, packet_logger),)

      # Create an RPC client.
      client = HdlcRpcClient(socket.read, protos, channels, stdout)
    """

    def __init__(self) -> None:
        # Until a consumer installs a real sender, anything passed to
        # send_packet() is silently discarded.
        self.send_packet: Callable[[bytes], Any] = lambda _: None

    @abc.abstractmethod
    def process_and_send(self, packet: bytes) -> None:
        """Processes an incoming packet before optionally sending it.

        Implementations of this method may send the processed packet, multiple
        packets, or no packets at all via the registered ``send_packet()``
        handler.
        """

    def __call__(self, data: bytes) -> None:
        """Lets an instance be used directly as a packet output callable."""
        self.process_and_send(data)
105
106
@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service.

    Instances compare by identity (``eq=False``). Because every ``Method``
    holds a reference back to its ``Service``, ``from_descriptor`` creates the
    service first and fills in ``methods`` afterwards.
    """

    _descriptor: ServiceDescriptor
    id: int
    methods: 'Methods'

    @property
    def name(self):
        """The service's short name, taken from its descriptor."""
        descriptor = self._descriptor
        return descriptor.name

    @property
    def full_name(self):
        """The service's full name, taken from its descriptor."""
        descriptor = self._descriptor
        return descriptor.full_name

    @property
    def package(self):
        """The proto package of the file that declares this service."""
        declaring_file = self._descriptor.file
        return declaring_file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> 'Service':
        """Builds a Service, including all of its Methods, from a descriptor.

        The service ID is calculated from the descriptor's full name.
        """
        # Construct with a placeholder for methods, since each Method needs
        # the finished Service object to point back to.
        new_service = cls(
            descriptor,
            ids.calculate(descriptor.full_name),
            None,  # type: ignore[arg-type]
        )
        service_methods = Methods(
            Method.from_descriptor(method_descriptor, new_service)
            for method_descriptor in descriptor.methods
        )
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(new_service, 'methods', service_methods)
        return new_service

    def __repr__(self) -> str:
        return 'Service({!r})'.format(self.full_name)

    def __str__(self) -> str:
        return self.full_name
150
151
def _streaming_attributes(method) -> Tuple[bool, bool]:
    """Returns ``(server_streaming, client_streaming)`` for a method.

    Args:
      method: a protobuf MethodDescriptor; its containing service's file
        descriptor is parsed to read the streaming flags.
    """
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    #     attributes to the generated protobuf code. As a workaround,
    #     deserialize the FileDescriptorProto to get that information.
    owning_service = method.containing_service

    parsed_file = descriptor_pb2.FileDescriptorProto()
    parsed_file.MergeFromString(owning_service.file.serialized_pb)

    service_proto = parsed_file.service[  # pylint: disable=no-member
        owning_service.index
    ]
    method_proto = service_proto.method[method.index]
    return method_proto.server_streaming, method_proto.client_streaming
165
166
# Maps protobuf scalar field types to the Python type used to describe them
# in help text (field_help) and signature annotations (_field_type_annotation).
# Enum fields are described as int here. Message and group fields have no
# entry.
_PROTO_FIELD_TYPES = {
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_BYTES: bytes,
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_ENUM: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_STRING: str,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
    # These types are not annotated:
    # FieldDescriptor.TYPE_GROUP = 10
    # FieldDescriptor.TYPE_MESSAGE = 11
}
188
189
def _field_type_annotation(field: FieldDescriptor):
    """Creates a field type annotation to use in the help message only.

    Message fields are annotated with their message class. Other fields map
    through _PROTO_FIELD_TYPES, and ``inspect.Parameter.empty`` is used when
    there is no mapping (e.g. group fields). Repeated fields are wrapped in
    ``Iterable[...]``.
    """
    if field.type == FieldDescriptor.TYPE_MESSAGE:
        message_type = field.message_type
        # Newer protobuf releases provide message_factory.GetMessageClass,
        # which returns one cached class per descriptor and supersedes the
        # deprecated MessageFactory.GetPrototype. Fall back to the older API
        # only when the newer one is unavailable.
        get_message_class = getattr(message_factory, 'GetMessageClass', None)
        if get_message_class is not None:
            annotation = get_message_class(message_type)
        else:
            annotation = message_factory.MessageFactory(
                message_type.file.pool
            ).GetPrototype(message_type)
    else:
        annotation = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)

    if field.label == FieldDescriptor.LABEL_REPEATED:
        return Iterable[annotation]  # type: ignore[valid-type]

    return annotation
203
204
def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Each string is ``name=default``, or ``name: type = default`` when
    ``annotations`` is true. A singular enum field's default is shown as its
    qualified enum value name. Message and group fields are described by the
    message type's full name.

    Args:
      proto_message: a protobuf message class or instance; only its
        ``DESCRIPTOR`` is used.
      annotations: whether to include each field's type in its string.
    """
    for field in proto_message.DESCRIPTOR.fields:
        repeated = field.label == FieldDescriptor.LABEL_REPEATED

        if field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
            if repeated:
                # A repeated field's default is a list, which cannot be looked
                # up in values_by_number (it is unhashable).
                value = repr(field.default_value)
            else:
                enum_value = field.enum_type.values_by_number[
                    field.default_value
                ]
                value = f'{type_name.rsplit(".", 1)[0]}.{enum_value.name}'
        elif field.type in (
            FieldDescriptor.TYPE_MESSAGE,
            FieldDescriptor.TYPE_GROUP,
        ):
            # Message-like fields have no _PROTO_FIELD_TYPES entry.
            type_name = field.message_type.full_name
            value = repr(field.default_value)
        else:
            type_name = _PROTO_FIELD_TYPES[field.type].__name__
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'
220
221
222def _message_is_type(proto, expected_type) -> bool:
223    """Returns true if the protobuf instance is the expected type."""
224    # Getting protobuf classes from google.protobuf.message_factory may create a
225    # new, unique generated proto class. Any generated classes for a particular
226    # proto message share the same MessageDescriptor instance and are
227    # interchangeable, so check the descriptors in addition to the types.
228    return isinstance(proto, expected_type) or (
229        isinstance(proto, Message)
230        and proto.DESCRIPTOR is expected_type.DESCRIPTOR
231    )
232
233
@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service.

    Instances compare by identity (``eq=False``). ``request_type`` and
    ``response_type`` hold the protobuf message classes for the method's
    input and output messages.
    """

    _descriptor: MethodDescriptor
    service: Service
    id: int
    server_streaming: bool
    client_streaming: bool
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(cls, descriptor: MethodDescriptor, service: Service):
        """Creates a Method from its descriptor and containing Service.

        The method ID is calculated from the method's short name. The request
        and response classes come from the descriptor's input and output
        message types.
        """

        def message_class(message_descriptor):
            # Newer protobuf releases provide message_factory.GetMessageClass,
            # which returns one cached class per descriptor and supersedes the
            # deprecated MessageFactory.GetPrototype. Fall back to the older
            # API only when the newer one is unavailable.
            get_message_class = getattr(
                message_factory, 'GetMessageClass', None
            )
            if get_message_class is not None:
                return get_message_class(message_descriptor)
            return message_factory.MessageFactory(
                message_descriptor.file.pool
            ).GetPrototype(message_descriptor)

        # Use cls (as Service.from_descriptor does) so that subclasses of
        # Method construct instances of themselves.
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            *_streaming_attributes(descriptor),
            message_class(descriptor.input_type),
            message_class(descriptor.output_type),
        )

    class Type(enum.Enum):
        """The kind of RPC, determined by which sides stream messages."""

        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            """Returns the name as lowercase words, e.g. 'server streaming'."""
            name: str = self.name  # pylint: disable=no-member
            return name.lower().replace('_', ' ')

    @property
    def name(self) -> str:
        """The method's short name."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The method's full name, taken from its descriptor."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The proto package of the file declaring the containing service."""
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> 'Method.Type':
        """The kind of RPC, derived from the two streaming flags."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(
        self, proto: Optional[Message], proto_kwargs: Optional[Dict[str, Any]]
    ) -> Message:
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Raises:
          TypeError: both a message and keyword arguments were given, or the
            given message is not of this method's request type.
        """
        if proto_kwargs is None:
            proto_kwargs = {}

        # Compare against None explicitly rather than relying on the
        # message object's truthiness.
        if proto is not None and proto_kwargs:
            # The repr of a message may be empty, so show '' in that case to
            # keep the error readable.
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})"
            )

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not _message_is_type(proto, self.request_type):
            # Non-protobuf objects have no DESCRIPTOR; name their Python type.
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(
                f'Expected a message of type '
                f'{self.request_type.DESCRIPTOR.full_name}, '
                f'got {bad_type}'
            )

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        Each field becomes a keyword-only parameter whose default is the
        field descriptor's default value. This can be used to make function
        signatures match the request proto.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(
                field.name,
                Parameter.KEYWORD_ONLY,
                annotation=_field_type_annotation(field),
                default=field.default_value,
            )

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type.

        The message name is package-relative when the message is in the
        service's package, and its full name otherwise.
        """
        stream = 'stream ' if streaming else ''

        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name
365
366
# Type variable for the item type held by the generic accessor classes below.
T = TypeVar('T')
368
369
def _name(item: Union[Service, Method]) -> str:
    """Returns the attribute name ServiceAccessor uses for an item.

    Services are named by their full name and methods by their short name.
    """
    if isinstance(item, Service):
        return item.full_name
    return item.name
372
373
class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure.

    The wrapped item is stored as an attribute called ``name``.
    ServiceAccessor pairs these wrappers with package names and passes them to
    python_protos.as_packages. Presumably this makes each item reachable as an
    attribute of its package object; confirm against as_packages.
    """

    def __init__(self, name: str, item: T):
        # setattr with a runtime name: the attribute name comes from the
        # item's proto name, not from a fixed field.
        setattr(self, name, item)
379
380
class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID.

    Values are stored by the ID of their key. String lookups are converted to
    an ID first, so an item can be found by either its name or its ID.
    """

    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: a dict mapping Service/Method keys to stored values, or an
            iterable of Services/Methods, each stored under itself.
          as_attrs: '' (or any falsy value) adds no attributes. 'members'
            adds one attribute per member, named by _name(key). 'packages'
            adds one attribute per package returned by
            python_protos.as_packages.

        Raises:
          ValueError: as_attrs is truthy but not 'members' or 'packages'.
        """
        if isinstance(members, dict):
            mapping = members
        else:
            # An iterable of items: each item is both key and value.
            mapping = {member: member for member in members}

        names_to_values = {_name(key): value for key, value in mapping.items()}
        self._by_id = {key.id: value for key, value in mapping.items()}

        if not as_attrs:
            return

        if as_attrs == 'members':
            for attr_name, value in names_to_values.items():
                setattr(self, attr_name, value)
        elif as_attrs == 'packages':
            package_entries = (
                (key.package, _AccessByName(key.name, mapping[key]))
                for key in mapping
            )
            for package in python_protos.as_packages(package_entries).packages:
                setattr(self, str(package), package)
        else:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getitem__(self, name_or_id: Union[str, int]):
        """Accesses a service/method by the string name or ID.

        Raises:
          KeyError: nothing is stored under the name's or ID's numeric ID.
        """
        item_id = _id(name_or_id)
        if item_id in self._by_id:
            return self._by_id[item_id]

        if isinstance(name_or_id, str):
            detail = f' ("{name_or_id}")'
        else:
            detail = ''
        raise KeyError(f'Unknown ID {item_id}{detail}')

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        # Accepts a name or an ID, matching __getitem__.
        return _id(name_or_id) in self._by_id

    def __repr__(self) -> str:
        listed = ', '.join(map(repr, self._by_id.values()))
        return f'{type(self).__name__}({listed})'
427
428
429def _id(handle: Union[str, int]) -> int:
430    return ids.calculate(handle) if isinstance(handle, str) else handle
431
432
class Methods(ServiceAccessor[Method]):
    """A collection of Method descriptors in a Service.

    Methods are looked up by name or ID. No per-method attributes are set,
    because ``as_attrs`` keeps its default.
    """

    # NOTE(review): the parameter is named ``method`` although it takes an
    # iterable of methods; renaming it would break keyword callers.
    def __init__(self, method: Iterable[Method]):
        super().__init__(method)
438
439
class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors.

    Services are looked up by full name or ID. No per-service attributes are
    set, because ``as_attrs`` keeps its default.
    """

    def __init__(self, services: Iterable[Service]):
        super().__init__(services)
445
446
def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: accessor to look the service up in by name.
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the method is not present
    """
    if '/' in name:
        parts = name.split('/')
    else:
        parts = name.rsplit('.', 1)

    # Exactly one service/method separator is required. Report a readable
    # error instead of a bare tuple-unpacking failure.
    if len(parts) != 2:
        raise ValueError(
            f'Invalid method name {name!r}; expected '
            '"package.Service/Method" or "package.Service.Method"'
        )

    service_name, method_name = parts

    service = service_accessor[service_name]
    # A Service is searched through its Methods accessor. Any other result is
    # assumed to support lookup by method name directly.
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]
467