# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Types representing the basic pw_rpc concepts: channel, service, method."""

import abc
from dataclasses import dataclass
import enum
from inspect import Parameter
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Generic,
    Iterable,
    Iterator,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

from google.protobuf import descriptor_pb2, message_factory
from google.protobuf.descriptor import (
    FieldDescriptor,
    MethodDescriptor,
    ServiceDescriptor,
)
from google.protobuf.message import Message
from pw_protobuf_compiler import python_protos

from pw_rpc import ids


@dataclass(frozen=True)
class Channel:
    """Pairs a channel ID with the callable that receives outgoing bytes."""

    # Numeric identifier of the channel.
    id: int
    # Called with the raw bytes to output; its return value is not used here.
    output: Callable[[bytes], Any]

    def __repr__(self) -> str:
        return f'Channel({self.id})'


class ChannelManipulator(abc.ABC):
    """A pipe interface that may manipulate packets before they're sent.

    ``ChannelManipulator``s allow application-specific packet handling to be
    injected into the packet processing pipeline for an ingress or egress
    channel-like pathway. This is particularly useful for integration testing
    resilience to things like packet loss on a usually-reliable transport. RPC
    server integrations (e.g. ``HdlcRpcLocalServerAndClient``) may provide an
    opportunity to inject a ``ChannelManipulator`` for this use case.

    A ``ChannelManipulator`` should not modify send_packet, as the consumer of a
    ``ChannelManipulator`` will use ``send_packet`` to insert the provided
    ``ChannelManipulator`` into a packet processing path.

    For example:

    .. code-block:: python

      class PacketLogger(ChannelManipulator):
          def process_and_send(self, packet: bytes) -> None:
              _LOG.debug('Received packet with payload: %s', str(packet))
              self.send_packet(packet)


      packet_logger = PacketLogger()

      # Configure actual send command.
      packet_logger.send_packet = socket.sendall

      # Route the output channel through the PacketLogger channel manipulator.
      channels = (Channel(_DEFAULT_CHANNEL, packet_logger),)

      # Create an RPC client.
      client = HdlcRpcClient(socket.read, protos, channels, stdout)
    """

    def __init__(self) -> None:
        # Defaults to a no-op so packets are dropped until the consumer
        # installs the real send function.
        self.send_packet: Callable[[bytes], Any] = lambda _: None

    @abc.abstractmethod
    def process_and_send(self, packet: bytes) -> None:
        """Processes an incoming packet before optionally sending it.

        Implementations of this method may send the processed packet, multiple
        packets, or no packets at all via the registered `send_packet()`
        handler.
        """

    def __call__(self, data: bytes) -> None:
        """Forwards to process_and_send so the instance works as a callable."""
        self.process_and_send(data)


@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service."""

    _descriptor: ServiceDescriptor
    id: int
    methods: 'Methods'

    @property
    def name(self) -> str:
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        return self._descriptor.file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> 'Service':
        """Creates a Service and its Methods from a protobuf descriptor.

        The service ID is calculated from the descriptor's full name.
        """
        # Each Method needs a reference to its Service, so the Service is
        # created first with a placeholder for its methods.
        service = cls(
            descriptor,
            ids.calculate(descriptor.full_name),
            None,  # type: ignore[arg-type]
        )
        # The dataclass is frozen, so bypass its __setattr__ to fill in the
        # methods now that the Service instance exists.
        object.__setattr__(
            service,
            'methods',
            Methods(
                Method.from_descriptor(method_descriptor, service)
                for method_descriptor in descriptor.methods
            ),
        )

        return service

    def __repr__(self) -> str:
        return f'Service({self.full_name!r})'

    def __str__(self) -> str:
        return self.full_name


def _streaming_attributes(method) -> Tuple[bool, bool]:
    """Returns (server_streaming, client_streaming) for a method descriptor."""
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    #     attributes to the generated protobuf code. As a workaround,
    #     deserialize the FileDescriptorProto to get that information.
    service = method.containing_service

    file_pb = descriptor_pb2.FileDescriptorProto()
    file_pb.MergeFromString(service.file.serialized_pb)

    method_pb = file_pb.service[service.index].method[
        method.index
    ]  # pylint: disable=no-member
    return method_pb.server_streaming, method_pb.client_streaming


_PROTO_FIELD_TYPES = {
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_BYTES: bytes,
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_ENUM: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_STRING: str,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
    # These types are not annotated:
    # FieldDescriptor.TYPE_GROUP = 10
    # FieldDescriptor.TYPE_MESSAGE = 11
}


def _message_class(descriptor):
    """Returns the generated message class for a message descriptor.

    Newer protobuf releases deprecate ``MessageFactory.GetPrototype`` in favor
    of the module-level ``message_factory.GetMessageClass``. Use the
    module-level function when the installed protobuf provides it and fall back
    to the original ``GetPrototype`` call otherwise.
    """
    get_message_class = getattr(message_factory, 'GetMessageClass', None)
    if get_message_class is not None:
        return get_message_class(descriptor)

    return message_factory.MessageFactory(descriptor.file.pool).GetPrototype(
        descriptor
    )


def _field_type_annotation(field: FieldDescriptor):
    """Creates a field type annotation to use in the help message only."""
    if field.type == FieldDescriptor.TYPE_MESSAGE:
        annotation = _message_class(field.message_type)
    else:
        # Types missing from the table (e.g. groups) get no annotation.
        annotation = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)

    if field.label == FieldDescriptor.LABEL_REPEATED:
        return Iterable[annotation]  # type: ignore[valid-type]

    return annotation


def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Args:
      proto_message: protobuf message class or instance; only its DESCRIPTOR is
        read.
      annotations: if True, yield 'name: type = default'; otherwise yield
        'name=default'.

    Yields:
      One string per field, in the order of the descriptor's fields.
    """
    for field in proto_message.DESCRIPTOR.fields:
        repeated = field.label == FieldDescriptor.LABEL_REPEATED

        if field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
            if repeated:
                # A repeated field's default is a list, which cannot be used as
                # a values_by_number key; show the list itself instead.
                value = repr(field.default_value)
            else:
                enum_value = field.enum_type.values_by_number[
                    field.default_value
                ].name
                # Qualify the value with the scope containing the enum.
                value = f'{type_name.rsplit(".", 1)[0]}.{enum_value}'
        else:
            scalar_type = _PROTO_FIELD_TYPES.get(field.type)
            if scalar_type is not None:
                type_name = scalar_type.__name__
            else:
                # Message and group fields are not in _PROTO_FIELD_TYPES;
                # describe them by their message type instead of raising
                # KeyError.
                type_name = field.message_type.full_name
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'


def _message_is_type(proto, expected_type) -> bool:
    """Returns true if the protobuf instance is the expected type."""
    # Getting protobuf classes from google.protobuf.message_factory may create a
    # new, unique generated proto class. Any generated classes for a particular
    # proto message share the same MessageDescriptor instance and are
    # interchangeable, so check the descriptors in addition to the types.
    return isinstance(proto, expected_type) or (
        isinstance(proto, Message)
        and proto.DESCRIPTOR is expected_type.DESCRIPTOR
    )


@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service."""

    _descriptor: MethodDescriptor
    service: Service
    id: int
    server_streaming: bool
    client_streaming: bool
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(
        cls, descriptor: MethodDescriptor, service: Service
    ) -> 'Method':
        """Creates a Method for the given service from a protobuf descriptor.

        The method ID is calculated from the method's short name.
        """
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            *_streaming_attributes(descriptor),
            _message_class(descriptor.input_type),
            _message_class(descriptor.output_type),
        )

    class Type(enum.Enum):
        """The four RPC kinds, by which sides of the call stream."""

        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            """Returns the name in lowercase with spaces for underscores."""
            return self.name.lower().replace(
                '_', ' '
            )  # pylint: disable=no-member

    @property
    def name(self) -> str:
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> 'Method.Type':
        """The Method.Type implied by the two streaming flags."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(
        self, proto: Optional[Message], proto_kwargs: Optional[Dict[str, Any]]
    ) -> Message:
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Raises:
          TypeError: both a message and keyword arguments were provided, or the
            message is not of this method's request type.
        """
        if proto_kwargs is None:
            proto_kwargs = {}

        if proto and proto_kwargs:
            # Substitute '' when the message's repr is empty so the error
            # message still reads sensibly.
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})"
            )

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not _message_is_type(proto, self.request_type):
            # Non-protobuf objects have no DESCRIPTOR; fall back to the Python
            # type name for the error message.
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(
                f'Expected a message of type '
                f'{self.request_type.DESCRIPTOR.full_name}, '
                f'got {bad_type}'
            )

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        This can be used to make function signatures match the request proto.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(
                field.name,
                Parameter.KEYWORD_ONLY,
                annotation=_field_type_annotation(field),
                default=field.default_value,
            )

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type."""
        stream = 'stream ' if streaming else ''

        # Types from the service's own package are shown by their short name.
        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name


T = TypeVar('T')


def _name(item: Union[Service, Method]) -> str:
    """Returns the attribute name for an item: full name for Services."""
    return item.full_name if isinstance(item, Service) else item.name


class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure."""

    def __init__(self, name: str, item: T):
        setattr(self, name, item)


class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID."""

    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: {item: value} dict or an iterable of values; each item must
            provide an ``id`` attribute.
          as_attrs: '' to add no attributes, 'members' to expose each value as
            an attribute named after its item, or 'packages' to expose values
            through attributes following the proto package structure.

        Raises:
          ValueError: as_attrs is not one of the supported values.
        """
        # If the members arg was passed as a [values] iterable, convert it to
        # an equivalent dictionary.
        if not isinstance(members, dict):
            members = {m: m for m in members}

        self._by_id = {k.id: v for k, v in members.items()}

        if as_attrs == 'members':
            for item, member in members.items():
                setattr(self, _name(item), member)
        elif as_attrs == 'packages':
            for package in python_protos.as_packages(
                (m.package, _AccessByName(m.name, members[m])) for m in members
            ).packages:
                setattr(self, str(package), package)
        elif as_attrs:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getitem__(self, name_or_id: Union[str, int]):
        """Accesses a service/method by the string name or ID.

        Raises:
          KeyError: no member has the given name or ID.
        """
        item_id = _id(name_or_id)

        try:
            return self._by_id[item_id]
        except KeyError:
            pass

        # Raise outside the except block so the internal lookup failure is not
        # chained onto the more descriptive error.
        name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
        raise KeyError(f'Unknown ID {item_id}{name}')

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        return _id(name_or_id) in self._by_id

    def __repr__(self) -> str:
        members = ', '.join(repr(m) for m in self._by_id.values())
        return f'{self.__class__.__name__}({members})'


def _id(handle: Union[str, int]) -> int:
    """Returns the ID for a handle: calculated for strings, as-is otherwise."""
    return ids.calculate(handle) if isinstance(handle, str) else handle


class Methods(ServiceAccessor[Method]):
    """A collection of Method descriptors in a Service."""

    def __init__(self, method: Iterable[Method]):
        super().__init__(method)


class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors."""

    def __init__(self, services: Iterable[Service]):
        super().__init__(services)


def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: accessor in which to look up the service by name.
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the method is not present
    """
    if '/' in name:
        parts = name.split('/')
    else:
        parts = name.rsplit('.', 1)

    # Exactly one service/method separator is required.
    if len(parts) != 2:
        raise ValueError(
            f'Invalid method name {name!r}; expected package.Service/Method '
            'or package.Service.Method'
        )

    service_name, method_name = parts

    service = service_accessor[service_name]
    # The accessor may hold Service objects or their method collections.
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]