# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Types representing the basic pw_rpc concepts: channel, service, method."""

from dataclasses import dataclass
import enum
from inspect import Parameter
from typing import (Any, Callable, Collection, Dict, Generic, Iterable,
                    Iterator, Tuple, TypeVar, Union)

from google.protobuf import descriptor_pb2, message_factory
from google.protobuf.descriptor import (FieldDescriptor, MethodDescriptor,
                                        ServiceDescriptor)
from google.protobuf.message import Message
from pw_protobuf_compiler import python_protos

from pw_rpc import ids


@dataclass(frozen=True)
class Channel:
    """Pairs a channel ID with the callable that receives outgoing bytes."""
    id: int
    output: Callable[[bytes], Any]

    def __repr__(self) -> str:
        return f'Channel({self.id})'


@dataclass(frozen=True, eq=False)
class Service:
    """Describes an RPC service."""
    _descriptor: ServiceDescriptor
    id: int
    methods: 'Methods'

    @property
    def name(self) -> str:
        """The service's name, as reported by its descriptor."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The service's full name, as reported by its descriptor."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The package of the file in which the service is declared."""
        return self._descriptor.file.package

    @classmethod
    def from_descriptor(cls, descriptor: ServiceDescriptor) -> 'Service':
        """Creates a Service and its Methods from a ServiceDescriptor.

        The service ID is calculated from the descriptor's full name.
        """
        # Each Method holds a reference to its Service, so the Service has to
        # exist before its methods can be created. Construct it without
        # methods first, then fill them in with object.__setattr__, which
        # bypasses the frozen dataclass's assignment protection.
        service = cls(descriptor, ids.calculate(descriptor.full_name),
                      None)  # type: ignore[arg-type]
        object.__setattr__(
            service, 'methods',
            Methods(
                Method.from_descriptor(method_descriptor, service)
                for method_descriptor in descriptor.methods))

        return service

    def __repr__(self) -> str:
        return f'Service({self.full_name!r})'

    def __str__(self) -> str:
        return self.full_name


def _streaming_attributes(method) -> Tuple[bool, bool]:
    """Returns (server_streaming, client_streaming) for a method descriptor."""
    # TODO(hepler): Investigate adding server_streaming and client_streaming
    #     attributes to the generated protobuf code. As a workaround,
    #     deserialize the FileDescriptorProto to get that information.
    service = method.containing_service

    file_pb = descriptor_pb2.FileDescriptorProto()
    file_pb.MergeFromString(service.file.serialized_pb)

    method_pb = file_pb.service[service.index].method[method.index]  # pylint: disable=no-member
    return method_pb.server_streaming, method_pb.client_streaming


# Python types used to annotate scalar protobuf fields.
_PROTO_FIELD_TYPES = {
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_BYTES: bytes,
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_ENUM: int,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: int,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: int,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: int,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: int,
    FieldDescriptor.TYPE_STRING: str,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: int,
    # These types are not annotated:
    # FieldDescriptor.TYPE_GROUP = 10
    # FieldDescriptor.TYPE_MESSAGE = 11
}

# Field types whose values are messages; they have no _PROTO_FIELD_TYPES entry.
_MESSAGE_FIELD_TYPES = (FieldDescriptor.TYPE_MESSAGE,
                        FieldDescriptor.TYPE_GROUP)


def _field_type_annotation(field: FieldDescriptor):
    """Creates a field type annotation to use in the help message only."""
    if field.type == FieldDescriptor.TYPE_MESSAGE:
        annotation = message_factory.MessageFactory(
            field.message_type.file.pool).GetPrototype(field.message_type)
    else:
        # Types missing from the table (e.g. groups) are left unannotated.
        annotation = _PROTO_FIELD_TYPES.get(field.type, Parameter.empty)

    if field.label == FieldDescriptor.LABEL_REPEATED:
        return Iterable[annotation]  # type: ignore[valid-type]

    return annotation


def field_help(proto_message, *, annotations: bool = False) -> Iterator[str]:
    """Yields argument strings for proto fields for use in a help message.

    Args:
      proto_message: protobuf message class or instance; only its DESCRIPTOR
          is used
      annotations: if True, yield 'name: type = default' strings; otherwise
          yield 'name=default' strings

    Yields:
      One string per field, in descriptor order.
    """
    for field in proto_message.DESCRIPTOR.fields:
        repeated = field.label == FieldDescriptor.LABEL_REPEATED

        if field.type == FieldDescriptor.TYPE_ENUM:
            type_name = field.enum_type.full_name
            if repeated:
                # The default of a repeated field is a list, which cannot be
                # looked up in values_by_number.
                value = repr(field.default_value)
            else:
                value_name = field.enum_type.values_by_number[
                    field.default_value].name
                # Enum values are scoped alongside the enum type, so qualify
                # the value with everything but the enum's own name.
                value = f'{type_name.rsplit(".", 1)[0]}.{value_name}'
        elif field.type in _MESSAGE_FIELD_TYPES:
            # Message and group fields are deliberately absent from
            # _PROTO_FIELD_TYPES; describe them by their message type.
            type_name = field.message_type.full_name
            value = repr(field.default_value)
        else:
            type_name = _PROTO_FIELD_TYPES[field.type].__name__
            value = repr(field.default_value)

        if annotations:
            yield f'{field.name}: {type_name} = {value}'
        else:
            yield f'{field.name}={value}'


def _message_is_type(proto, expected_type) -> bool:
    """Returns true if the protobuf instance is the expected type."""
    # Getting protobuf classes from google.protobuf.message_factory may create a
    # new, unique generated proto class. Any generated classes for a particular
    # proto message share the same MessageDescriptor instance and are
    # interchangeable, so check the descriptors in addition to the types.
    return isinstance(proto, expected_type) or (isinstance(
        proto, Message) and proto.DESCRIPTOR is expected_type.DESCRIPTOR)


@dataclass(frozen=True, eq=False)
class Method:
    """Describes a method in a service."""

    _descriptor: MethodDescriptor
    service: Service
    id: int
    server_streaming: bool
    client_streaming: bool
    request_type: Any
    response_type: Any

    @classmethod
    def from_descriptor(cls, descriptor: MethodDescriptor,
                        service: Service) -> 'Method':
        """Creates a Method from a MethodDescriptor and its owning Service.

        The method ID is calculated from the descriptor's (unqualified) name.
        """
        input_factory = message_factory.MessageFactory(
            descriptor.input_type.file.pool)
        output_factory = message_factory.MessageFactory(
            descriptor.output_type.file.pool)
        return cls(
            descriptor,
            service,
            ids.calculate(descriptor.name),
            *_streaming_attributes(descriptor),
            input_factory.GetPrototype(descriptor.input_type),
            output_factory.GetPrototype(descriptor.output_type),
        )

    class Type(enum.Enum):
        """The four RPC kinds, by which sides of the call stream."""
        UNARY = 0
        SERVER_STREAMING = 1
        CLIENT_STREAMING = 2
        BIDIRECTIONAL_STREAMING = 3

        def sentence_name(self) -> str:
            """Returns the type name in lower case with spaces."""
            return self.name.lower().replace('_', ' ')  # pylint: disable=no-member

    @property
    def name(self) -> str:
        """The method's name, as reported by its descriptor."""
        return self._descriptor.name

    @property
    def full_name(self) -> str:
        """The method's full name, as reported by its descriptor."""
        return self._descriptor.full_name

    @property
    def package(self) -> str:
        """The package of the file that declares the containing service."""
        return self._descriptor.containing_service.file.package

    @property
    def type(self) -> 'Method.Type':
        """The Method.Type implied by the two streaming flags."""
        if self.server_streaming and self.client_streaming:
            return self.Type.BIDIRECTIONAL_STREAMING

        if self.server_streaming:
            return self.Type.SERVER_STREAMING

        if self.client_streaming:
            return self.Type.CLIENT_STREAMING

        return self.Type.UNARY

    def get_request(self, proto, proto_kwargs: Dict[str, Any]):
        """Returns a request_type protobuf message.

        The client implementation may use this to support providing a request
        as either a message object or as keyword arguments for the message's
        fields (but not both).

        Args:
          proto: a request message, or None to build one from proto_kwargs
          proto_kwargs: field values used to construct the request when proto
              is None

        Raises:
          TypeError: both proto and proto_kwargs were provided, or proto is not
              a message of this method's request type
        """
        # Test against None explicitly so a falsy non-None argument is still
        # recognized as "a request object was provided".
        if proto is not None and proto_kwargs:
            proto_str = repr(proto).strip() or "''"
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided '
                f"({proto_str} and {proto_kwargs!r})")

        if proto is None:
            return self.request_type(**proto_kwargs)

        if not _message_is_type(proto, self.request_type):
            # Prefer the proto type name; fall back to the Python type name
            # for objects that are not protobuf messages at all.
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(f'Expected a message of type '
                            f'{self.request_type.DESCRIPTOR.full_name}, '
                            f'got {bad_type}')

        return proto

    def request_parameters(self) -> Iterator[Parameter]:
        """Yields inspect.Parameters corresponding to the request's fields.

        This can be used to make function signatures match the request proto.
        """
        for field in self.request_type.DESCRIPTOR.fields:
            yield Parameter(field.name,
                            Parameter.KEYWORD_ONLY,
                            annotation=_field_type_annotation(field),
                            default=field.default_value)

    def __repr__(self) -> str:
        req = self._method_parameter(self.request_type, self.client_streaming)
        res = self._method_parameter(self.response_type, self.server_streaming)
        return f'<{self.full_name}({req}) returns ({res})>'

    def _method_parameter(self, proto, streaming: bool) -> str:
        """Returns a description of the method's request or response type."""
        stream = 'stream ' if streaming else ''

        # Types from the service's own package are shown without the package.
        if proto.DESCRIPTOR.file.package == self.service.package:
            return stream + proto.DESCRIPTOR.name

        return stream + proto.DESCRIPTOR.full_name

    def __str__(self) -> str:
        return self.full_name


T = TypeVar('T')


def _name(item: Union[Service, Method]) -> str:
    """Returns the name an item is keyed by: full name for a Service."""
    return item.full_name if isinstance(item, Service) else item.name


class _AccessByName(Generic[T]):
    """Wrapper for accessing types by name within a proto package structure."""
    def __init__(self, name: str, item: T):
        setattr(self, name, item)


class ServiceAccessor(Collection[T]):
    """Navigates RPC services by name or ID."""
    def __init__(self, members, as_attrs: str = ''):
        """Creates accessor from an {item: value} dict or [values] iterable.

        Args:
          members: {Service or Method: value} dict, or an iterable of
              Services or Methods
          as_attrs: '' to add no attributes, 'members' to expose each value as
              an attribute named after its item, or 'packages' to expose the
              values through attributes that follow the proto package
              structure

        Raises:
          ValueError: as_attrs is not one of the supported values
        """
        # If the members arg was passed as a [values] iterable, convert it to
        # an equivalent dictionary.
        if not isinstance(members, dict):
            members = {m: m for m in members}

        by_name = {_name(k): v for k, v in members.items()}
        self._by_id = {k.id: v for k, v in members.items()}

        if as_attrs == 'members':
            for name, member in by_name.items():
                setattr(self, name, member)
        elif as_attrs == 'packages':
            for package in python_protos.as_packages(
                (m.package, _AccessByName(m.name, members[m]))
                    for m in members).packages:
                setattr(self, str(package), package)
        elif as_attrs:
            raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')

    def __getitem__(self, name_or_id: Union[str, int]):
        """Accesses a service/method by the string name or ID."""
        item_id = _id(name_or_id)

        try:
            return self._by_id[item_id]
        except KeyError:
            pass

        # Raise outside of the except block so the more descriptive KeyError
        # is not chained to the internal dictionary lookup failure.
        name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
        raise KeyError(f'Unknown ID {item_id}{name}')

    def __iter__(self) -> Iterator[T]:
        return iter(self._by_id.values())

    def __len__(self) -> int:
        return len(self._by_id)

    def __contains__(self, name_or_id) -> bool:
        return _id(name_or_id) in self._by_id

    def __repr__(self) -> str:
        members = ', '.join(repr(m) for m in self._by_id.values())
        return f'{self.__class__.__name__}({members})'


def _id(handle: Union[str, int]) -> int:
    """Converts a string name to its calculated ID; passes other values on."""
    return ids.calculate(handle) if isinstance(handle, str) else handle


class Methods(ServiceAccessor[Method]):
    """A collection of Method descriptors in a Service."""
    def __init__(self, method: Iterable[Method]):
        super().__init__(method)


class Services(ServiceAccessor[Service]):
    """A collection of Service descriptors."""
    def __init__(self, services: Iterable[Service]):
        super().__init__(services)


def get_method(service_accessor: ServiceAccessor, name: str):
    """Returns a method matching the given full name in a ServiceAccessor.

    Args:
      service_accessor: accessor in which to look up the service; its values
          may be Services or accessors of methods
      name: name as package.Service/Method or package.Service.Method.

    Raises:
      ValueError: the method name is not properly formatted
      KeyError: the method is not present
    """
    parts = name.split('/') if '/' in name else name.rsplit('.', 1)

    if len(parts) != 2:
        raise ValueError(
            f'Invalid method name {name!r}; expected '
            'package.Service/Method or package.Service.Method')

    service_name, method_name = parts

    service = service_accessor[service_name]
    if isinstance(service, Service):
        service = service.methods

    return service[method_name]