# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Types representing the basic pw_rpc concepts: channel, service, method."""

import abc
from dataclasses import dataclass
import enum
from inspect import Parameter
from typing import (Any, Callable, Collection, Dict, Generic, Iterable,
                    Iterator, Optional, Tuple, TypeVar, Union)

from google.protobuf import descriptor_pb2, message_factory
from google.protobuf.descriptor import (FieldDescriptor, MethodDescriptor,
                                        ServiceDescriptor)
from google.protobuf.message import Message
from pw_protobuf_compiler import python_protos

from pw_rpc import ids


@dataclass(frozen=True)
class Channel:
    """An RPC channel: a numeric ID paired with a function that sends bytes.

    Attributes:
      id: the channel ID.
      output: callable invoked with each outgoing packet's bytes.
    """
    id: int
    output: Callable[[bytes], Any]

    def __repr__(self) -> str:
        return f'Channel({self.id})'


class ChannelManipulator(abc.ABC):
    """A pipe interface that may manipulate packets before they're sent.

    ``ChannelManipulator``s allow application-specific packet handling to be
    injected into the packet processing pipeline for an ingress or egress
    channel-like pathway. This is particularly useful for integration testing
    resilience to things like packet loss on a usually-reliable transport. RPC
    server integrations (e.g. ``HdlcRpcLocalServerAndClient``) may provide an
    opportunity to inject a ``ChannelManipulator`` for this use case.

    A ``ChannelManipulator`` should not modify send_packet, as the consumer of a
    ``ChannelManipulator`` will use ``send_packet`` to insert the provided
    ``ChannelManipulator`` into a packet processing path.

    For example:

    .. code-block:: python

      class PacketLogger(ChannelManipulator):
          def process_and_send(self, packet: bytes) -> None:
              _LOG.debug('Received packet with payload: %s', str(packet))
              self.send_packet(packet)


      packet_logger = PacketLogger()

      # Configure actual send command.
      packet_logger.send_packet = socket.sendall

      # Route the output channel through the PacketLogger channel manipulator.
      channels = (Channel(_DEFAULT_CHANNEL, packet_logger),)

      # Create a RPC client.
      client = HdlcRpcClient(socket.read, protos, channels, stdout)
    """
    def __init__(self):
        # Defaults to a no-op sink until the consumer installs a real one.
        self.send_packet: Callable[[bytes], Any] = lambda _: None

    @abc.abstractmethod
    def process_and_send(self, packet: bytes) -> None:
        """Processes an incoming packet before optionally sending it.

        Implementations of this method may send the processed packet, multiple
        packets, or no packets at all via the registered `send_packet()`
        handler.
        """

    def __call__(self, data: bytes) -> None:
        # Calling the manipulator lets it stand in for a Channel's output.
        self.process_and_send(data)


@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service."""
    _descriptor: ServiceDescriptor
    id: int
    methods: 'Methods'

    @property
    def name(self):
        return self._descriptor.name

    @property
    def full_name(self):
        return self._descriptor.full_name

    @property
    def package(self):
        return self._descriptor.file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> 'Service':
        """Creates a Service, and all of its Methods, from a descriptor."""
        # Methods need a reference to their Service, so the Service is created
        # first with a placeholder and the frozen field is filled in afterward.
        service = cls(descriptor, ids.calculate(descriptor.full_name),
                      None)  # type: ignore[arg-type]
        object.__setattr__(
            service, 'methods',
            Methods(
                Method.from_descriptor(method_descriptor, service)
                for method_descriptor in descriptor.methods))

        return service

    def __repr__(self) -> str:
        return f'Service({self.full_name!r})'

    def __str__(self) -> str:
        return self.full_name


def _streaming_attributes(method) -> Tuple[bool, bool]:
    """Returns (server_streaming, client_streaming) for a MethodDescriptor."""
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    #     attributes to the generated protobuf code. As a workaround,
    #     deserialize the FileDescriptorProto to get that information.
    service = method.containing_service

    file_pb = descriptor_pb2.FileDescriptorProto()
    file_pb.MergeFromString(service.file.serialized_pb)

    method_pb = file_pb.service[service.index].method[method.index]  # pylint: disable=no-member
    return method_pb.server_streaming, method_pb.client_streaming


_PROTO_FIELD_TYPES = {
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_BYTES: bytes,
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_ENUM: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_STRING: str,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
    # These types are not annotated:
    # FieldDescriptor.TYPE_GROUP = 10
    # FieldDescriptor.TYPE_MESSAGE = 11
}


def _field_type_annotation(field: FieldDescriptor):
    """Creates a field type annotation to use in the help message only."""
    if field.type == FieldDescriptor.TYPE_MESSAGE:
        annotation = message_factory.MessageFactory(
            field.message_type.file.pool).GetPrototype(field.message_type)
    else:
        annotation = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)

    if field.label == FieldDescriptor.LABEL_REPEATED:
        return Iterable[annotation]  # type: ignore[valid-type]

    return annotation


def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Args:
      proto_message: a protobuf message class or instance; only its
          ``DESCRIPTOR`` is read.
      annotations: if True, yield ``name: type = default`` strings; otherwise
          yield ``name=default`` strings.

    Yields:
      One string per field, in the order of ``DESCRIPTOR.fields``.
    """
    for field in proto_message.DESCRIPTOR.fields:
        if field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
            if field.label == FieldDescriptor.LABEL_REPEATED:
                # Repeated fields have a list default, which cannot be looked
                # up as a single enum number.
                value = repr(field.default_value)
            else:
                value_name = field.enum_type.values_by_number[
                    field.default_value].name
                # Enum values are scoped alongside their enum type, so show
                # them relative to the enum's enclosing scope.
                value = f'{type_name.rsplit(".", 1)[0]}.{value_name}'
        elif field.type in (FieldDescriptor.TYPE_MESSAGE,
                            FieldDescriptor.TYPE_GROUP):
            # Message/group fields have no entry in _PROTO_FIELD_TYPES; name
            # them by their message type instead of raising KeyError.
            type_name = field.message_type.full_name
            value = repr(field.default_value)
        else:
            type_name = _PROTO_FIELD_TYPES[field.type].__name__
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'


def _message_is_type(proto, expected_type) -> bool:
    """Returns true if the protobuf instance is the expected type."""
    # Getting protobuf classes from google.protobuf.message_factory may create a
    # new, unique generated proto class. Any generated classes for a particular
    # proto message share the same MessageDescriptor instance and are
    # interchangeable, so check the descriptors in addition to the types.
    return isinstance(proto, expected_type) or (isinstance(
        proto, Message) and proto.DESCRIPTOR is expected_type.DESCRIPTOR)


@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service."""

    _descriptor: MethodDescriptor
    service: Service
    id: int
    server_streaming: bool
    client_streaming: bool
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(cls, descriptor: MethodDescriptor, service: Service):
        """Creates a Method (or subclass) from a MethodDescriptor."""
        input_factory = message_factory.MessageFactory(
            descriptor.input_type.file.pool)
        output_factory = message_factory.MessageFactory(
            descriptor.output_type.file.pool)
        # Use cls so that subclasses of Method construct their own type.
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            *_streaming_attributes(descriptor),
            input_factory.GetPrototype(descriptor.input_type),
            output_factory.GetPrototype(descriptor.output_type),
        )

    class Type(enum.Enum):
        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            return self.name.lower().replace('_', ' ')  # pylint: disable=no-member

    @property
    def name(self) -> str:
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> 'Method.Type':
        """Classifies the method by which sides of the call stream."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(self, proto: Optional[Message],
                    proto_kwargs: Optional[Dict[str, Any]]) -> Message:
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Raises:
          TypeError: both a message and keyword arguments were provided, or the
              message is not of this method's request type.
        """
        if proto_kwargs is None:
            proto_kwargs = {}

        # Any provided message conflicts with keyword args, regardless of the
        # message object's truthiness.
        if proto is not None and proto_kwargs:
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})")

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not _message_is_type(proto, self.request_type):
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(f'Expected a message of type '
                            f'{self.request_type.DESCRIPTOR.full_name}, '
                            f'got {bad_type}')

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        This can be used to make function signatures match the request proto.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(field.name,
                            Parameter.KEYWORD_ONLY,
                            annotation=_field_type_annotation(field),
                            default=field.default_value)

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type."""
        stream = 'stream ' if streaming else ''

        # Omit the package when it matches the service's own package.
        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name


T = TypeVar('T')


def _name(item: Union[Service, Method]) -> str:
    """Services are keyed by full name; methods by their short name."""
    return item.full_name if isinstance(item, Service) else item.name


class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure."""
    def __init__(self, name: str, item: T):
        setattr(self, name, item)


class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID."""
    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: a dict mapping Service/Method keys to values, or an
              iterable of Service/Method objects (each maps to itself).
          as_attrs: '' for no attribute access; 'members' to expose each
              value as an attribute named after its key; 'packages' to expose
              values through their proto package hierarchy.

        Raises:
          ValueError: as_attrs is not one of the accepted values.
        """
        # If the members arg was passed as a [values] iterable, convert it to
        # an equivalent dictionary.
        if not isinstance(members, dict):
            members = {m: m for m in members}

        self._by_id = {k.id: v for k, v in members.items()}

        if as_attrs == 'members':
            for key, member in members.items():
                setattr(self, _name(key), member)
        elif as_attrs == 'packages':
            for package in python_protos.as_packages(
                (m.package, _AccessByName(m.name, members[m]))
                    for m in members).packages:
                setattr(self, str(package), package)
        elif as_attrs:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getitem__(self, name_or_id: Union[str, int]):
        """Accesses a service/method by the string name or ID."""
        item_id = _id(name_or_id)
        try:
            return self._by_id[item_id]
        except KeyError:
            name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
            raise KeyError(f'Unknown ID {item_id}{name}') from None

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        return _id(name_or_id) in self._by_id

    def __repr__(self) -> str:
        members = ', '.join(repr(m) for m in self._by_id.values())
        return f'{self.__class__.__name__}({members})'


def _id(handle: Union[str, int]) -> int:
    """Converts a name to its calculated ID; integer IDs pass through."""
    return ids.calculate(handle) if isinstance(handle, str) else handle


class Methods(ServiceAccessor[Method]):
    """A collection of Method descriptors in a Service."""
    def __init__(self, method: Iterable[Method]):
        super().__init__(method)


class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors."""
    def __init__(self, services: Iterable[Service]):
        super().__init__(services)


def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: a ServiceAccessor whose items are Services, or whose
          items expose a ``methods`` accessor is not required (non-Service
          items are indexed directly by method name).
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the method is not present
    """
    try:
        if '/' in name:
            service_name, method_name = name.split('/')
        else:
            service_name, method_name = name.rsplit('.', 1)
    except ValueError:
        raise ValueError(
            f'Invalid method name {name!r}; expected '
            'package.Service/Method or package.Service.Method') from None

    service = service_accessor[service_name]
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]