• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Types representing the basic pw_rpc concepts: channel, service, method."""
15
16from dataclasses import dataclass
17import enum
18from inspect import Parameter
19from typing import (Any, Callable, Collection, Dict, Generic, Iterable,
20                    Iterator, Tuple, TypeVar, Union)
21
22from google.protobuf import descriptor_pb2, message_factory
23from google.protobuf.descriptor import (FieldDescriptor, MethodDescriptor,
24                                        ServiceDescriptor)
25from google.protobuf.message import Message
26from pw_protobuf_compiler import python_protos
27
28from pw_rpc import ids
29
30
@dataclass(frozen=True)
class Channel:
    """An immutable pairing of a channel ID with the function it outputs to.

    Only the ID appears in the repr; the output callable is omitted.
    """
    # Numeric identifier of the channel.
    id: int
    # Called with bytes to emit on this channel (presumably encoded packets;
    # confirm against callers). Its return value is not used here.
    output: Callable[[bytes], Any]

    def __repr__(self) -> str:
        channel_id = self.id
        return f'Channel({channel_id})'
38
39
@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service.

    Wraps a protobuf ServiceDescriptor together with the service's numeric ID
    and its collection of methods. Because eq=False, equality and hashing are
    identity-based, so instances can serve as dictionary keys.
    """
    _descriptor: ServiceDescriptor
    id: int
    methods: 'Methods'

    @property
    def name(self) -> str:
        """The service name as reported by its descriptor."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The package-qualified service name from the descriptor."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The protobuf package of the file that declares this service."""
        return self._descriptor.file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> 'Service':
        """Builds a Service, including all of its Methods, from a descriptor.

        The service ID is ids.calculate() of the descriptor's full name.
        """
        # Every Method refers back to its Service, so the Service has to exist
        # first. Create it with a placeholder, then fill in the frozen
        # `methods` field through object.__setattr__.
        service = cls(_descriptor=descriptor,
                      id=ids.calculate(descriptor.full_name),
                      methods=None)  # type: ignore[arg-type]
        service_methods = Methods(
            Method.from_descriptor(method_descriptor, service)
            for method_descriptor in descriptor.methods)
        object.__setattr__(service, 'methods', service_methods)
        return service

    def __repr__(self) -> str:
        return f'Service({self.full_name!r})'

    def __str__(self) -> str:
        return self.full_name
76
77
def _streaming_attributes(method) -> Tuple[bool, bool]:
    """Returns (server_streaming, client_streaming) for a MethodDescriptor.

    The flags are read from the method's entry in a FileDescriptorProto
    parsed from the containing file's serialized_pb.
    """
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    #     attributes to the generated protobuf code. As a workaround,
    #     deserialize the FileDescriptorProto to get that information.
    owner = method.containing_service

    parsed_file = descriptor_pb2.FileDescriptorProto()
    parsed_file.MergeFromString(owner.file.serialized_pb)

    service_proto = parsed_file.service[owner.index]  # pylint: disable=no-member
    method_proto = service_proto.method[method.index]
    return method_proto.server_streaming, method_proto.client_streaming
89
90
# Maps protobuf scalar field types to the Python type used to describe them:
# _field_type_annotation uses the type as a parameter annotation, and
# field_help shows its __name__. Enum fields map to int.
_PROTO_FIELD_TYPES = {
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_BYTES: bytes,
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_ENUM: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_STRING: str,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
    # These types are not annotated, so they have no entry here; a plain
    # [] lookup with them raises KeyError:
    # FieldDescriptor.TYPE_GROUP = 10
    # FieldDescriptor.TYPE_MESSAGE = 11
}
112
113
def _field_type_annotation(field: FieldDescriptor):
    """Creates a field type annotation to use in the help message only.

    Message fields are annotated with their generated message class. Other
    fields use the Python type from _PROTO_FIELD_TYPES, or Parameter.empty
    for types missing from that table. Repeated fields are wrapped in
    Iterable[...].
    """
    if field.type != FieldDescriptor.TYPE_MESSAGE:
        element = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)
    else:
        factory = message_factory.MessageFactory(field.message_type.file.pool)
        element = factory.GetPrototype(field.message_type)

    if field.label != FieldDescriptor.LABEL_REPEATED:
        return element

    return Iterable[element]  # type: ignore[valid-type]
126
127
def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Args:
      proto_message: protobuf message class or instance; only its DESCRIPTOR
          is read.
      annotations: if True, yield 'name: type = default' strings; otherwise
          yield 'name=default' strings.

    Yields:
      One string per field, in DESCRIPTOR.fields order.
    """
    for field in proto_message.DESCRIPTOR.fields:
        if field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
        elif field.type in (FieldDescriptor.TYPE_MESSAGE,
                            FieldDescriptor.TYPE_GROUP):
            # Message and group types have no _PROTO_FIELD_TYPES entry, so
            # indexing the table with them would raise KeyError.
            type_name = field.message_type.full_name
        else:
            type_name = _PROTO_FIELD_TYPES[field.type].__name__

        if (field.type == FieldDescriptor.TYPE_ENUM
                and field.label != FieldDescriptor.LABEL_REPEATED):
            # Show the enum default by name, as a sibling of the enum type:
            # drop the enum's own name from its full name and append the
            # value's name.
            value_name = field.enum_type.values_by_number[
                field.default_value].name
            value = f'{type_name.rsplit(".", 1)[0]}.{value_name}'
        else:
            # Repeated fields (repeated enums included) report a list as
            # their default_value rather than an enum number, so looking it
            # up in values_by_number would fail; show it with repr() instead.
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'
143
144
145def _message_is_type(proto, expected_type) -> bool:
146    """Returns true if the protobuf instance is the expected type."""
147    # Getting protobuf classes from google.protobuf.message_factory may create a
148    # new, unique generated proto class. Any generated classes for a particular
149    # proto message share the same MessageDescriptor instance and are
150    # interchangeable, so check the descriptors in addition to the types.
151    return isinstance(proto, expected_type) or (isinstance(
152        proto, Message) and proto.DESCRIPTOR is expected_type.DESCRIPTOR)
153
154
@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service.

    Wraps a protobuf MethodDescriptor together with the owning Service, the
    method's numeric ID, its two streaming flags, and the generated request
    and response message classes. Because eq=False, equality and hashing are
    identity-based.
    """

    _descriptor: MethodDescriptor
    service: Service
    id: int
    server_streaming: bool
    client_streaming: bool
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(cls, descriptor: MethodDescriptor, service: Service):
        """Creates a Method for descriptor that belongs to service.

        The method ID is ids.calculate() of the method's short name. Request
        and response classes come from message factories built on the
        descriptor pools of the input and output types.
        """
        input_factory = message_factory.MessageFactory(
            descriptor.input_type.file.pool)
        output_factory = message_factory.MessageFactory(
            descriptor.output_type.file.pool)
        # Construct through cls (as Service.from_descriptor does) so this
        # alternate constructor also produces instances of subclasses.
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            *_streaming_attributes(descriptor),
            input_factory.GetPrototype(descriptor.input_type),
            output_factory.GetPrototype(descriptor.output_type),
        )

    class Type(enum.Enum):
        """The RPC kind, derived from the two streaming flags."""
        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            """Returns the member name as lowercase words separated by spaces."""
            return self.name.lower().replace('_', ' ')  # pylint: disable=no-member

    @property
    def name(self) -> str:
        """The method's short name from its descriptor."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The method's fully qualified name from its descriptor."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The protobuf package of the file declaring the containing service."""
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> 'Method.Type':
        """The Method.Type implied by server_streaming and client_streaming."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(self, proto, proto_kwargs: Dict[str, Any]):
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Args:
          proto: a request message, or None to build one from proto_kwargs.
          proto_kwargs: field values for a new request_type message; used only
              when proto is None.

        Returns:
          proto itself if it is a request_type message, otherwise a new
          request_type message constructed from proto_kwargs.

        Raises:
          TypeError: both proto and proto_kwargs were provided, or proto is
              not a request_type message.
        """
        # Compare against None explicitly, matching the check below, rather
        # than relying on the truthiness of whatever object was passed.
        if proto is not None and proto_kwargs:
            # Fall back to "''" when the message's repr is blank.
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})")

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not _message_is_type(proto, self.request_type):
            # Non-protobuf objects have no DESCRIPTOR; name their Python type.
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(f'Expected a message of type '
                            f'{self.request_type.DESCRIPTOR.full_name}, '
                            f'got {bad_type}')

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        This can be used to make function signatures match the request proto.
        Each parameter is keyword-only, annotated via _field_type_annotation,
        and defaults to the field descriptor's default_value.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(field.name,
                            Parameter.KEYWORD_ONLY,
                            annotation=_field_type_annotation(field),
                            default=field.default_value)

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type.

        The message is named by its short name when it is in the service's
        package and by its full name otherwise, prefixed with 'stream ' when
        streaming is true.
        """
        stream = 'stream ' if streaming else ''

        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name
272
273
# Item type held by ServiceAccessor and wrapped by _AccessByName.
T = TypeVar('T')
275
276
def _name(item: Union[Service, Method]) -> str:
    """Returns the name item is looked up by.

    Services are named by their full name; anything else (methods) by its
    short name.
    """
    if isinstance(item, Service):
        return item.full_name
    return item.name
279
280
class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure.

    The wrapped item is stored as an attribute called ``name``, so
    ``_AccessByName('Foo', x).Foo is x``.
    """
    def __init__(self, name: str, item: T):
        # setattr because the attribute name is only known at runtime.
        setattr(self, name, item)
285
286
class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID.

    Values are stored keyed by each item's numeric ``id``. A string key is
    converted to an ID with ``_id`` (``ids.calculate``) before lookup, so a
    name and its ID address the same entry.
    """
    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: a dict mapping items (services/methods) to the values to
              expose, or an iterable of items, each exposed as itself.
          as_attrs: '' for no attribute access; 'members' to set each value
              as an attribute named after its item; 'packages' to expose the
              values through package attributes built by
              python_protos.as_packages.

        Raises:
          ValueError: as_attrs is not '', 'members', or 'packages'.
        """
        # A plain iterable of items is treated as an {item: item} mapping.
        item_to_value = (members if isinstance(members, dict) else
                         {item: item for item in members})

        # Names are computed eagerly for every item, before the ID table.
        values_by_name = {
            _name(item): value
            for item, value in item_to_value.items()
        }
        self._by_id = {
            item.id: value
            for item, value in item_to_value.items()
        }

        if as_attrs == 'members':
            for attr_name, value in values_by_name.items():
                setattr(self, attr_name, value)
        elif as_attrs == 'packages':
            package_entries = (
                (item.package, _AccessByName(item.name, item_to_value[item]))
                for item in item_to_value)
            package_tree = python_protos.as_packages(package_entries)
            for package in package_tree.packages:
                setattr(self, str(package), package)
        elif as_attrs:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getitem__(self, name_or_id: Union[str, int]):
        """Accesses a service/method by the string name or ID.

        Raises:
          KeyError: nothing is stored under the given ID (or the ID computed
              from the given name).
        """
        key = _id(name_or_id)
        try:
            return self._by_id[key]
        except KeyError:
            pass

        detail = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
        raise KeyError(f'Unknown ID {key}{detail}')

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        return _id(name_or_id) in self._by_id

    def __repr__(self) -> str:
        listed = ', '.join(map(repr, self._by_id.values()))
        return f'{self.__class__.__name__}({listed})'
332
333
334def _id(handle: Union[str, int]) -> int:
335    return ids.calculate(handle) if isinstance(handle, str) else handle
336
337
class Methods(ServiceAccessor[Method]):
    """A collection of Method descriptors in a Service.

    Methods are addressed by their short name or by their numeric ID.
    """
    def __init__(self, method: Iterable[Method]):
        # Each Method maps to itself; no attribute access is configured.
        super().__init__(members=method)
342
343
class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors.

    Services are addressed by their full name or by their numeric ID.
    """
    def __init__(self, services: Iterable[Service]):
        # Each Service maps to itself; no attribute access is configured.
        super().__init__(members=services)
348
349
def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: accessor to search. A value that is a Service is
          searched through its methods; any other value must itself support
          [method_name] lookup.
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the service or the method is not present
    """
    if '/' in name:
        parts = name.split('/')
    else:
        parts = name.rsplit('.', 1)

    # Exactly one separator is required, yielding a service part and a method
    # part; report malformed names explicitly instead of via tuple unpacking.
    if len(parts) != 2:
        raise ValueError(
            f'Invalid method name {name!r}; expected package.Service/Method '
            'or package.Service.Method')

    service_name, method_name = parts
    service = service_accessor[service_name]
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]
370