# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""This module defines data structures for protobuf entities."""

import abc
import collections
import enum

from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
from typing import cast

import google.protobuf.descriptor_pb2 as descriptor_pb2

T = TypeVar('T')  # pylint: disable=invalid-name


class ProtoNode(abc.ABC):
    """A ProtoNode represents a C++ scope mapping of an entity in a .proto file.

    Nodes form a tree beginning at a top-level (global) scope, descending into a
    hierarchy of .proto packages and the messages and enums defined within them.
    """
    class Type(enum.Enum):
        """The type of a ProtoNode.

        PACKAGE maps to a C++ namespace.
        MESSAGE maps to a C++ "Encoder" class within its own namespace.
        ENUM maps to a C++ enum within its parent's namespace.
        EXTERNAL represents a node defined within a different compilation unit.
        SERVICE represents an RPC service definition.
        """
        PACKAGE = 1
        MESSAGE = 2
        ENUM = 3
        EXTERNAL = 4
        SERVICE = 5

    def __init__(self, name: str):
        self._name: str = name
        # Children are keyed by name and kept in insertion order.
        self._children: Dict[str, 'ProtoNode'] = collections.OrderedDict()
        self._parent: Optional['ProtoNode'] = None

    @abc.abstractmethod
    def type(self) -> 'ProtoNode.Type':
        """The type of the node."""

    def children(self) -> List['ProtoNode']:
        """Returns a new list of this node's direct children."""
        return list(self._children.values())

    def name(self) -> str:
        """Returns the name this node was constructed with."""
        return self._name

    def cpp_name(self) -> str:
        """The name of this node in generated C++ code."""
        return self._name.replace('.', '::')

    def cpp_namespace(self, root: Optional['ProtoNode'] = None) -> str:
        """C++ namespace of the node, up to the specified root.

        The C++ names of every node from (but excluding) root down to this node
        are joined with '::'. Nodes with an empty name are skipped.
        """
        return '::'.join(name for name in self._attr_hierarchy(
            lambda node: node.cpp_name(), root) if name)

    def proto_path(self) -> str:
        """Fully-qualified package path of the node."""
        path = '.'.join(self._attr_hierarchy(lambda node: node.name(), None))
        return path.lstrip('.')

    def nanopb_name(self) -> str:
        """Full nanopb-style name of the node.

        The names of all nodes from the top of the tree down to this node are
        joined with underscores.
        """
        # Skip unnamed nodes (e.g. the global root, which is created with an
        # empty name) instead of stripping leading underscores from the joined
        # string: stripping would also remove underscores that belong to the
        # first real name, turning "_internal_Msg" into "internal_Msg".
        return '_'.join(name for name in self._attr_hierarchy(
            lambda node: node.name(), None) if name)

    def common_ancestor(self, other: 'ProtoNode') -> Optional['ProtoNode']:
        """Finds the nearest node that contains both this node and other.

        If the two nodes are the same node, or one is an ancestor of the other,
        that node itself is returned.

        Args:
          other: The node to compare against.

        Returns:
          The deepest node shared by both parent chains, or None if other is
          None or the nodes do not belong to the same tree.
        """

        if other is None:
            return None

        own_depth = self.depth()
        other_depth = other.depth()
        diff = abs(own_depth - other_depth)

        # "second" is always the deeper (or equally deep) of the two nodes.
        if own_depth < other_depth:
            first: Optional['ProtoNode'] = self
            second: Optional['ProtoNode'] = other
        else:
            first = other
            second = self

        # Raise the deeper node until both are at the same depth.
        while diff > 0:
            assert second is not None
            second = second.parent()
            diff -= 1

        # Walk both chains upwards in lockstep until they meet.
        while first != second:
            if first is None or second is None:
                return None

            first = first.parent()
            second = second.parent()

        return first

    def depth(self) -> int:
        """Returns the depth of this node from the root."""
        depth = 0
        node = self._parent
        while node is not None:
            depth += 1
            node = node.parent()
        return depth

    def add_child(self, child: 'ProtoNode') -> None:
        """Inserts a new node into the tree as a child of this node.

        If the child already has a parent, it is removed from that parent
        first.

        Args:
          child: The node to insert.

        Raises:
          ValueError: This node does not allow nesting the given type of child.
        """
        if not self._supports_child(child):
            raise ValueError('Invalid child %s for node of type %s' %
                             (child.type(), self.type()))

        # pylint: disable=protected-access
        if child._parent is not None:
            del child._parent._children[child.name()]

        child._parent = self
        self._children[child.name()] = child
        # pylint: enable=protected-access

    def find(self, path: str) -> Optional['ProtoNode']:
        """Finds a node within this node's subtree.

        Args:
          path: '.'-separated child names, relative to this node.

        Returns:
          The node at the path, or None if any section of it does not exist.
        """
        node = self

        # pylint: disable=protected-access
        for section in path.split('.'):
            child = node._children.get(section)
            if child is None:
                return None
            node = child
        # pylint: enable=protected-access

        return node

    def parent(self) -> Optional['ProtoNode']:
        """Returns this node's parent, or None if it has not been attached."""
        return self._parent

    def __iter__(self) -> Iterator['ProtoNode']:
        """Iterates depth-first through all nodes in this node's subtree."""
        yield self
        for child in self._children.values():
            # Iterating a child recursively yields its whole subtree.
            yield from child

    def _attr_hierarchy(self, attr_accessor: Callable[['ProtoNode'], T],
                        root: Optional['ProtoNode']) -> Iterator[T]:
        """Fetches node attributes at each level of the tree from the root.

        Args:
          attr_accessor: Function which extracts attributes from a ProtoNode.
          root: The node at which to terminate. The root itself is not
            included; pass None to walk all the way to the top of the tree.

        Returns:
          An iterator to a list of the selected attributes from the root to the
          current node.
        """
        hierarchy = []
        node: Optional['ProtoNode'] = self
        while node is not None and node != root:
            hierarchy.append(attr_accessor(node))
            node = node.parent()
        # Attributes were collected leaf-first; callers want them root-first.
        return reversed(hierarchy)

    @abc.abstractmethod
    def _supports_child(self, child: 'ProtoNode') -> bool:
        """Returns True if child is a valid child type for the current node."""


class ProtoPackage(ProtoNode):
    """A protobuf package."""
    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.PACKAGE

    def _supports_child(self, child: ProtoNode) -> bool:
        # Packages accept every kind of node.
        return True


class ProtoEnum(ProtoNode):
    """Representation of an enum in a .proto file."""
    def __init__(self, name: str):
        super().__init__(name)
        self._values: List[Tuple[str, int]] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.ENUM

    def values(self) -> List[Tuple[str, int]]:
        """Returns a copy of the (name, number) pairs added to this enum."""
        return list(self._values)

    def add_value(self, name: str, value: int) -> None:
        """Appends a value to the enum, storing its name in upper case."""
        self._values.append((ProtoMessageField.upper_snake_case(name), value))

    def _supports_child(self, child: ProtoNode) -> bool:
        # Enums cannot have nested children.
        return False


class ProtoMessage(ProtoNode):
    """Representation of a message in a .proto file."""
    def __init__(self, name: str):
        super().__init__(name)
        self._fields: List['ProtoMessageField'] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.MESSAGE

    def fields(self) -> List['ProtoMessageField']:
        """Returns a copy of the list of fields added to this message."""
        return list(self._fields)

    def add_field(self, field: 'ProtoMessageField') -> None:
        """Appends a field to the message."""
        self._fields.append(field)

    def _supports_child(self, child: ProtoNode) -> bool:
        # Only enums and other messages may be nested inside a message.
        return (child.type() == self.Type.ENUM
                or child.type() == self.Type.MESSAGE)


class ProtoService(ProtoNode):
    """Representation of a service in a .proto file."""
    def __init__(self, name: str):
        super().__init__(name)
        self._methods: List['ProtoServiceMethod'] = []

    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.SERVICE

    def methods(self) -> List['ProtoServiceMethod']:
        """Returns a copy of the list of methods added to this service."""
        return list(self._methods)

    def add_method(self, method: 'ProtoServiceMethod') -> None:
        """Appends a method to the service."""
        self._methods.append(method)

    def _supports_child(self, child: ProtoNode) -> bool:
        # Services cannot have nested children.
        return False


class ProtoExternal(ProtoNode):
    """A node from a different compilation unit.

    An external node is one that isn't defined within the current compilation
    unit, most likely as it comes from an imported proto file. Its type is not
    known, so it does not have any members or additional data. Its purpose
    within the node graph is to provide namespace resolution between compile
    units.
    """
    def type(self) -> ProtoNode.Type:
        return ProtoNode.Type.EXTERNAL

    def _supports_child(self, child: ProtoNode) -> bool:
        # The real type is unknown, so any child is accepted.
        return True


# This class is not a node and does not appear in the proto tree.
# Fields belong to proto messages and are processed separately.
class ProtoMessageField:
    """Representation of a field within a protobuf message."""
    def __init__(self,
                 field_name: str,
                 field_number: int,
                 field_type: int,
                 type_node: Optional[ProtoNode] = None,
                 repeated: bool = False):
        self._field_name = field_name
        self._number: int = field_number
        self._type: int = field_type
        self._type_node: Optional[ProtoNode] = type_node
        self._repeated: bool = repeated

    def name(self) -> str:
        """Returns the field name converted to UpperCamelCase."""
        return self.upper_camel_case(self._field_name)

    def enum_name(self) -> str:
        """Returns the field name converted to UPPER_SNAKE_CASE."""
        return self.upper_snake_case(self._field_name)

    def number(self) -> int:
        """Returns the field number."""
        return self._number

    def type(self) -> int:
        """Returns the integer field type passed to the constructor."""
        return self._type

    def type_node(self) -> Optional[ProtoNode]:
        """Returns the node for the field's type, if one was provided."""
        return self._type_node

    def is_repeated(self) -> bool:
        """Returns True if the field was declared as repeated."""
        return self._repeated

    @staticmethod
    def upper_camel_case(field_name: str) -> str:
        """Converts a field name to UpperCamelCase."""
        return ''.join(component.lower().capitalize()
                       for component in field_name.split('_'))

    @staticmethod
    def upper_snake_case(field_name: str) -> str:
        """Converts a field name to UPPER_SNAKE_CASE."""
        return field_name.upper()


class ProtoServiceMethod:
    """A method defined in a protobuf service."""
    class Type(enum.Enum):
        """The streaming mode of a service method."""
        UNARY = 'kUnary'
        SERVER_STREAMING = 'kServerStreaming'
        CLIENT_STREAMING = 'kClientStreaming'
        BIDIRECTIONAL_STREAMING = 'kBidirectionalStreaming'

        def cc_enum(self) -> str:
            """Returns the pw_rpc MethodType C++ enum for this method type."""
            return '::pw::rpc::internal::MethodType::' + self.value

    def __init__(self, name: str, method_type: Type, request_type: ProtoNode,
                 response_type: ProtoNode):
        self._name = name
        self._type = method_type
        self._request_type = request_type
        self._response_type = response_type

    def name(self) -> str:
        """Returns the method name."""
        return self._name

    def type(self) -> Type:
        """Returns the streaming mode of the method."""
        return self._type

    def server_streaming(self) -> bool:
        """Returns True if the server streams responses."""
        return (self._type is self.Type.SERVER_STREAMING
                or self._type is self.Type.BIDIRECTIONAL_STREAMING)

    def client_streaming(self) -> bool:
        """Returns True if the client streams requests."""
        return (self._type is self.Type.CLIENT_STREAMING
                or self._type is self.Type.BIDIRECTIONAL_STREAMING)

    def request_type(self) -> ProtoNode:
        """Returns the node for the method's request type."""
        return self._request_type

    def response_type(self) -> ProtoNode:
        """Returns the node for the method's response type."""
        return self._response_type


def _add_enum_fields(enum_node: ProtoNode, proto_enum) -> None:
    """Adds fields from a protobuf enum descriptor to an enum node."""
    assert enum_node.type() == ProtoNode.Type.ENUM
    enum_node = cast(ProtoEnum, enum_node)

    for value in proto_enum.value:
        enum_node.add_value(value.name, value.number)


def _create_external_nodes(root: ProtoNode, path: str) -> ProtoNode:
    """Creates external nodes for a path starting from the given root.

    Sections of the path that already exist are reused; a ProtoExternal is
    created for each missing one. Returns the node for the final section.
    """

    node = root
    for part in path.split('.'):
        child = node.find(part)
        if child is None:
            child = ProtoExternal(part)
            node.add_child(child)
        node = child

    return node


def _find_or_create_node(global_root: ProtoNode, package_root: ProtoNode,
                         path: str) -> ProtoNode:
    """Searches the proto tree for a node by path, creating it if not found.

    Paths starting with '.' are resolved from global_root; all other paths are
    resolved relative to package_root.
    """

    if path[0] == '.':
        # Fully qualified path.
        root_relative_path = path[1:]
        search_root = global_root
    else:
        root_relative_path = path
        search_root = package_root

    node = search_root.find(root_relative_path)
    if node is None:
        # Create nodes for field types that don't exist within this
        # compilation context, such as those imported from other .proto
        # files.
        node = _create_external_nodes(search_root, root_relative_path)

    return node


def _add_message_fields(global_root: ProtoNode, package_root: ProtoNode,
                        message: ProtoNode, proto_message) -> None:
    """Adds fields from a protobuf message descriptor to a message node."""
    assert message.type() == ProtoNode.Type.MESSAGE
    message = cast(ProtoMessage, message)

    type_node: Optional[ProtoNode]

    for field in proto_message.field:
        if field.type_name:
            # The "type_name" member contains the global .proto path of the
            # field's type object, for example ".pw.protobuf.test.KeyValuePair".
            # Try to find the node for this object within the current context.
            type_node = _find_or_create_node(global_root, package_root,
                                             field.type_name)
        else:
            type_node = None

        repeated = \
            field.label == descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
        message.add_field(
            ProtoMessageField(
                field.name,
                field.number,
                field.type,
                type_node,
                repeated,
            ))


def _add_service_methods(global_root: ProtoNode, package_root: ProtoNode,
                         service: ProtoNode, proto_service) -> None:
    """Adds methods from a protobuf service descriptor to a service node."""
    assert service.type() == ProtoNode.Type.SERVICE
    service = cast(ProtoService, service)

    for method in proto_service.method:
        if method.client_streaming and method.server_streaming:
            method_type = ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING
        elif method.client_streaming:
            method_type = ProtoServiceMethod.Type.CLIENT_STREAMING
        elif method.server_streaming:
            method_type = ProtoServiceMethod.Type.SERVER_STREAMING
        else:
            method_type = ProtoServiceMethod.Type.UNARY

        request_node = _find_or_create_node(global_root, package_root,
                                            method.input_type)
        response_node = _find_or_create_node(global_root, package_root,
                                             method.output_type)

        service.add_method(
            ProtoServiceMethod(method.name, method_type, request_node,
                               response_node))


def _populate_fields(proto_file, global_root: ProtoNode,
                     package_root: ProtoNode) -> None:
    """Traverses a proto file, adding all message and enum fields to a tree."""
    def populate_message(node, message):
        """Recursively populates nested messages and enums."""
        _add_message_fields(global_root, package_root, node, message)

        for proto_enum in message.enum_type:
            nested_enum = node.find(proto_enum.name)
            assert nested_enum is not None
            _add_enum_fields(nested_enum, proto_enum)
        for msg in message.nested_type:
            nested_message = node.find(msg.name)
            assert nested_message is not None
            populate_message(nested_message, msg)

    # Iterate through the proto file, populating top-level objects.
    for proto_enum in proto_file.enum_type:
        enum_node = package_root.find(proto_enum.name)
        assert enum_node is not None
        _add_enum_fields(enum_node, proto_enum)

    for message in proto_file.message_type:
        message_node = package_root.find(message.name)
        assert message_node is not None
        populate_message(message_node, message)

    for service in proto_file.service:
        service_node = package_root.find(service.name)
        assert service_node is not None
        _add_service_methods(global_root, package_root, service_node, service)


def _build_hierarchy(proto_file) -> Tuple[ProtoNode, ProtoNode]:
    """Creates a ProtoNode hierarchy from a proto file descriptor.

    Only the nodes themselves are created; fields, enum values and service
    methods are not filled in here.

    Returns:
      A tuple of the unnamed global root node and the innermost node of the
      file's package.
    """

    root = ProtoPackage('')
    package_root = root

    # An empty package name still splits into a single empty section, so the
    # package root is always a child of the global root.
    for part in proto_file.package.split('.'):
        package = ProtoPackage(part)
        package_root.add_child(package)
        package_root = package

    def build_message_subtree(proto_message):
        """Recursively builds the node for a message and its nested types."""
        node = ProtoMessage(proto_message.name)
        for proto_enum in proto_message.enum_type:
            node.add_child(ProtoEnum(proto_enum.name))
        for submessage in proto_message.nested_type:
            node.add_child(build_message_subtree(submessage))

        return node

    for proto_enum in proto_file.enum_type:
        package_root.add_child(ProtoEnum(proto_enum.name))

    for message in proto_file.message_type:
        package_root.add_child(build_message_subtree(message))

    for service in proto_file.service:
        package_root.add_child(ProtoService(service.name))

    return root, package_root


def build_node_tree(file_descriptor_proto) -> Tuple[ProtoNode, ProtoNode]:
    """Constructs a tree of proto nodes from a file descriptor.

    Returns the root node of the entire proto package tree and the node
    representing the file's package.
    """
    global_root, package_root = _build_hierarchy(file_descriptor_proto)
    _populate_fields(file_descriptor_proto, global_root, package_root)
    return global_root, package_root