• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""This module defines data structures for protobuf entities."""
15
16import abc
17import collections
18import enum
19
20from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
21from typing import cast
22
23import google.protobuf.descriptor_pb2 as descriptor_pb2
24
25T = TypeVar('T')  # pylint: disable=invalid-name
26
27
class ProtoNode(abc.ABC):
    """A ProtoNode represents a C++ scope mapping of an entity in a .proto file.

    Nodes form a tree beginning at a top-level (global) scope, descending into a
    hierarchy of .proto packages and the messages and enums defined within them.
    """
    class Type(enum.Enum):
        """The type of a ProtoNode.

        PACKAGE maps to a C++ namespace.
        MESSAGE maps to a C++ "Encoder" class within its own namespace.
        ENUM maps to a C++ enum within its parent's namespace.
        EXTERNAL represents a node defined within a different compilation unit.
        SERVICE represents an RPC service definition.
        """
        PACKAGE = 1
        MESSAGE = 2
        ENUM = 3
        EXTERNAL = 4
        SERVICE = 5

    def __init__(self, name: str):
        self._name: str = name
        # Children are keyed by name and kept in insertion order.
        self._children: Dict[str, 'ProtoNode'] = collections.OrderedDict()
        self._parent: Optional['ProtoNode'] = None

    @abc.abstractmethod
    def type(self) -> 'ProtoNode.Type':
        """The type of the node."""

    def children(self) -> List['ProtoNode']:
        """Returns a new list of the direct children, in insertion order."""
        return list(self._children.values())

    def name(self) -> str:
        """Returns the name this node was constructed with."""
        return self._name

    def cpp_name(self) -> str:
        """The name of this node in generated C++ code."""
        return self._name.replace('.', '::')

    def cpp_namespace(self, root: Optional['ProtoNode'] = None) -> str:
        """C++ namespace of the node, up to the specified root.

        The root itself is excluded, and nodes whose C++ name is empty are
        skipped so that they do not produce empty '::' segments.
        """
        return '::'.join(name for name in self._attr_hierarchy(
            lambda node: node.cpp_name(), root) if name)

    def proto_path(self) -> str:
        """Fully-qualified package path of the node."""
        path = '.'.join(self._attr_hierarchy(lambda node: node.name(), None))
        # Empty-named ancestors leave leading separators behind; drop them.
        return path.lstrip('.')

    def nanopb_name(self) -> str:
        """Full nanopb-style name of the node."""
        name = '_'.join(self._attr_hierarchy(lambda node: node.name(), None))
        # Empty-named ancestors leave leading separators behind; drop them.
        return name.lstrip('_')

    def common_ancestor(self, other: 'ProtoNode') -> Optional['ProtoNode']:
        """Finds the earliest common ancestor of this node and other.

        Returns:
          The deepest node that is an ancestor of (or equal to) both nodes, or
          None if other is None or the two nodes do not share a tree.
        """

        if other is None:
            return None

        own_depth = self.depth()
        other_depth = other.depth()
        diff = abs(own_depth - other_depth)

        # "second" is always the deeper (or equally deep) of the two nodes.
        if own_depth < other_depth:
            first: Optional['ProtoNode'] = self
            second: Optional['ProtoNode'] = other
        else:
            first = other
            second = self

        # Raise the deeper node until both sit at the same depth.
        while diff > 0:
            assert second is not None
            second = second.parent()
            diff -= 1

        # Climb in lockstep until the two paths meet.
        while first is not second:
            if first is None or second is None:
                return None

            first = first.parent()
            second = second.parent()

        return first

    def depth(self) -> int:
        """Returns the depth of this node from the root."""
        depth = 0
        node = self._parent
        while node is not None:
            depth += 1
            node = node.parent()
        return depth

    def add_child(self, child: 'ProtoNode') -> None:
        """Inserts a new node into the tree as a child of this node.

        If the child already has a parent it is moved from that parent. If this
        node already has a different child with the same name, that child is
        replaced and becomes parentless.

        Args:
          child: The node to insert.

        Raises:
          ValueError: This node does not allow nesting the given type of child.
        """
        if not self._supports_child(child):
            raise ValueError('Invalid child %s for node of type %s' %
                             (child.type(), self.type()))

        # pylint: disable=protected-access
        # Detach from the previous parent, but only if that parent's entry is
        # really this node; otherwise an unrelated same-named sibling would be
        # removed instead.
        previous_parent = child._parent
        if (previous_parent is not None
                and previous_parent._children.get(child.name()) is child):
            del previous_parent._children[child.name()]

        # A different node being overwritten must not keep pointing at this
        # node as its parent once it is no longer in the child table.
        replaced = self._children.get(child.name())
        if replaced is not None and replaced is not child:
            replaced._parent = None

        child._parent = self
        self._children[child.name()] = child
        # pylint: enable=protected-access

    def find(self, path: str) -> Optional['ProtoNode']:
        """Finds a node within this node's subtree.

        Args:
          path: Dot-separated child names, relative to this node.

        Returns:
          The node at the path, or None if any section of it does not exist.
        """
        node = self

        # pylint: disable=protected-access
        for section in path.split('.'):
            child = node._children.get(section)
            if child is None:
                return None
            node = child
        # pylint: enable=protected-access

        return node

    def parent(self) -> Optional['ProtoNode']:
        """Returns the parent of this node, or None if it has no parent."""
        return self._parent

    def __iter__(self) -> Iterator['ProtoNode']:
        """Iterates depth-first through all nodes in this node's subtree."""
        yield self
        for child in self._children.values():
            yield from child

    def _attr_hierarchy(self, attr_accessor: Callable[['ProtoNode'], T],
                        root: Optional['ProtoNode']) -> Iterator[T]:
        """Fetches node attributes at each level of the tree from the root.

        Args:
          attr_accessor: Function which extracts attributes from a ProtoNode.
          root: The node at which to terminate. Its own attribute is not
            included; pass None to walk all the way to the top of the tree.

        Returns:
          An iterator to a list of the selected attributes from the root to the
          current node.
        """
        hierarchy = []
        node: Optional['ProtoNode'] = self
        # Collect from this node upwards, then flip into root-first order.
        while node is not None and node is not root:
            hierarchy.append(attr_accessor(node))
            node = node.parent()
        return reversed(hierarchy)

    @abc.abstractmethod
    def _supports_child(self, child: 'ProtoNode') -> bool:
        """Returns True if child is a valid child type for the current node."""
190
191
class ProtoPackage(ProtoNode):
    """Node for a single segment of a .proto package path."""
    def type(self) -> ProtoNode.Type:
        """Identifies this node as a package."""
        return ProtoNode.Type.PACKAGE

    def _supports_child(self, child: ProtoNode) -> bool:
        """Accepts every kind of child node."""
        return True
199
200
class ProtoEnum(ProtoNode):
    """An enum declared in a .proto file, along with its values."""
    def __init__(self, name: str):
        super().__init__(name)
        # (name, number) pairs in the order in which they were added.
        self._values: List[Tuple[str, int]] = []

    def type(self) -> ProtoNode.Type:
        """Identifies this node as an enum."""
        return ProtoNode.Type.ENUM

    def values(self) -> List[Tuple[str, int]]:
        """Returns a copy of the (name, number) pairs added so far."""
        return self._values[:]

    def add_value(self, name: str, value: int) -> None:
        """Records an enum value, storing its name in upper case."""
        entry = (ProtoMessageField.upper_snake_case(name), value)
        self._values.append(entry)

    def _supports_child(self, child: ProtoNode) -> bool:
        """Rejects all children: nothing can be nested inside an enum."""
        return False
219
220
class ProtoMessage(ProtoNode):
    """A message declared in a .proto file, along with its fields."""
    def __init__(self, name: str):
        super().__init__(name)
        # Fields in the order in which they were added.
        self._fields: List['ProtoMessageField'] = []

    def type(self) -> ProtoNode.Type:
        """Identifies this node as a message."""
        return ProtoNode.Type.MESSAGE

    def fields(self) -> List['ProtoMessageField']:
        """Returns a copy of the list of fields added so far."""
        return self._fields[:]

    def add_field(self, field: 'ProtoMessageField') -> None:
        """Appends a field to this message."""
        self._fields.append(field)

    def _supports_child(self, child: ProtoNode) -> bool:
        """Accepts only nested enums and nested messages."""
        nestable = (self.Type.ENUM, self.Type.MESSAGE)
        return child.type() in nestable
239
240
class ProtoService(ProtoNode):
    """A service declared in a .proto file, along with its methods."""
    def __init__(self, name: str):
        super().__init__(name)
        # Methods in the order in which they were added.
        self._methods: List['ProtoServiceMethod'] = []

    def type(self) -> ProtoNode.Type:
        """Identifies this node as a service."""
        return ProtoNode.Type.SERVICE

    def methods(self) -> List['ProtoServiceMethod']:
        """Returns a copy of the list of methods added so far."""
        return self._methods[:]

    def add_method(self, method: 'ProtoServiceMethod') -> None:
        """Appends a method to this service."""
        self._methods.append(method)

    def _supports_child(self, child: ProtoNode) -> bool:
        """Rejects all children: services hold methods, not nested nodes."""
        return False
258
259
class ProtoExternal(ProtoNode):
    """Placeholder for an entity defined in another compilation unit.

    External nodes stand in for entities that the current compilation unit
    references but does not define, most likely because they come from an
    imported proto file. The real type of such an entity is unknown, so the
    node carries no members or extra data; it exists within the node graph only
    so that namespaces can be resolved across compile units.
    """
    def type(self) -> ProtoNode.Type:
        """Identifies this node as external."""
        return ProtoNode.Type.EXTERNAL

    def _supports_child(self, child: ProtoNode) -> bool:
        """Accepts every kind of child node."""
        return True
274
275
276# This class is not a node and does not appear in the proto tree.
277# Fields belong to proto messages and are processed separately.
class ProtoMessageField:
    """A single field declared inside a protobuf message."""
    def __init__(self,
                 field_name: str,
                 field_number: int,
                 field_type: int,
                 type_node: Optional[ProtoNode] = None,
                 repeated: bool = False):
        # The name is stored as given; name() and enum_name() derive the
        # case-converted variants on demand.
        self._field_name = field_name
        self._number: int = field_number
        self._type: int = field_type
        self._type_node: Optional[ProtoNode] = type_node
        self._repeated: bool = repeated

    def name(self) -> str:
        """Returns the field name converted to UpperCamelCase."""
        return self.upper_camel_case(self._field_name)

    def enum_name(self) -> str:
        """Returns the field name converted to upper case."""
        return self.upper_snake_case(self._field_name)

    def number(self) -> int:
        """Returns the field number."""
        return self._number

    def type(self) -> int:
        """Returns the integer field type given at construction."""
        return self._type

    def type_node(self) -> Optional[ProtoNode]:
        """Returns the node given as this field's type, or None."""
        return self._type_node

    def is_repeated(self) -> bool:
        """Returns True if the field was constructed as repeated."""
        return self._repeated

    @staticmethod
    def upper_camel_case(field_name: str) -> str:
        """Converts a field name to UpperCamelCase.

        The name is split on underscores; each piece is lowercased and then
        capitalized before the pieces are joined back together.
        """
        return ''.join(piece.lower().capitalize()
                       for piece in field_name.split('_'))

    @staticmethod
    def upper_snake_case(field_name: str) -> str:
        """Converts a field name to UPPER_SNAKE_CASE by upper-casing it."""
        return field_name.upper()
322
323
class ProtoServiceMethod:
    """One RPC method declared inside a protobuf service."""
    class Type(enum.Enum):
        """Streaming kind of a method; each value names a C++ enumerator."""
        UNARY = 'kUnary'
        SERVER_STREAMING = 'kServerStreaming'
        CLIENT_STREAMING = 'kClientStreaming'
        BIDIRECTIONAL_STREAMING = 'kBidirectionalStreaming'

        def cc_enum(self) -> str:
            """Returns the pw_rpc MethodType C++ enum for this method type."""
            return f'::pw::rpc::internal::MethodType::{self.value}'

    def __init__(self, name: str, method_type: Type, request_type: ProtoNode,
                 response_type: ProtoNode):
        self._name = name
        self._type = method_type
        self._request_type = request_type
        self._response_type = response_type

    def name(self) -> str:
        """Returns the method name."""
        return self._name

    def type(self) -> Type:
        """Returns the streaming kind of the method."""
        return self._type

    def server_streaming(self) -> bool:
        """Returns True for server-streaming and bidirectional methods."""
        kinds = self.Type
        if self._type is kinds.SERVER_STREAMING:
            return True
        return self._type is kinds.BIDIRECTIONAL_STREAMING

    def client_streaming(self) -> bool:
        """Returns True for client-streaming and bidirectional methods."""
        kinds = self.Type
        if self._type is kinds.CLIENT_STREAMING:
            return True
        return self._type is kinds.BIDIRECTIONAL_STREAMING

    def request_type(self) -> ProtoNode:
        """Returns the node given as the method's request type."""
        return self._request_type

    def response_type(self) -> ProtoNode:
        """Returns the node given as the method's response type."""
        return self._response_type
362
363
def _add_enum_fields(enum_node: ProtoNode, proto_enum) -> None:
    """Copies every value of an enum descriptor onto an enum node.

    Args:
      enum_node: The tree node to populate; must be of type ENUM.
      proto_enum: Enum descriptor whose `value` entries supply the name and
        number of each enum value.
    """
    assert enum_node.type() == ProtoNode.Type.ENUM
    target = cast(ProtoEnum, enum_node)

    for entry in proto_enum.value:
        target.add_value(entry.name, entry.number)
371
372
def _create_external_nodes(root: ProtoNode, path: str) -> ProtoNode:
    """Walks a dotted path below root, adding external nodes where needed.

    Segments that already exist are reused; each missing segment gets a new
    ProtoExternal node.

    Returns:
      The node for the final segment of the path.
    """

    current = root
    for segment in path.split('.'):
        existing = current.find(segment)
        if existing:
            current = existing
            continue

        # Nothing by this name yet: stand in with an external node.
        placeholder = ProtoExternal(segment)
        current.add_child(placeholder)
        current = placeholder

    return current
385
386
def _find_or_create_node(global_root: ProtoNode, package_root: ProtoNode,
                         path: str) -> ProtoNode:
    """Looks up a node by path, creating external nodes if it is missing.

    Args:
      global_root: Root used for fully qualified paths (leading '.').
      package_root: Root used for all other paths.
      path: Dotted path of the node to look up.

    Returns:
      The existing node at the path, or a newly created external node.
    """

    # A leading dot marks a fully qualified path, which is resolved against
    # the global root once the dot has been removed.
    fully_qualified = path[0] == '.'
    search_root = global_root if fully_qualified else package_root
    relative_path = path[1:] if fully_qualified else path

    found = search_root.find(relative_path)
    if found is not None:
        return found

    # The type is not defined within this compilation context (for example
    # because it is imported from another .proto file), so stand in for it
    # with external nodes.
    return _create_external_nodes(search_root, relative_path)
407
408
def _add_message_fields(global_root: ProtoNode, package_root: ProtoNode,
                        message: ProtoNode, proto_message) -> None:
    """Copies every field of a message descriptor onto a message node.

    Args:
      global_root: Root used to resolve fully qualified type names.
      package_root: Root used to resolve all other type names.
      message: The tree node to populate; must be of type MESSAGE.
      proto_message: Message descriptor whose `field` entries are added.
    """
    assert message.type() == ProtoNode.Type.MESSAGE
    target = cast(ProtoMessage, message)

    repeated_label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED

    for descriptor in proto_message.field:
        # "type_name" holds the global .proto path of the field's type object
        # (for example ".pw.protobuf.test.KeyValuePair"). When it is set,
        # resolve it to a node within the current context.
        type_node: Optional[ProtoNode] = None
        if descriptor.type_name:
            type_node = _find_or_create_node(global_root, package_root,
                                             descriptor.type_name)

        is_repeated = descriptor.label == repeated_label
        new_field = ProtoMessageField(
            descriptor.name,
            descriptor.number,
            descriptor.type,
            type_node,
            is_repeated,
        )
        target.add_field(new_field)
437
438
def _add_service_methods(global_root: ProtoNode, package_root: ProtoNode,
                         service: ProtoNode, proto_service) -> None:
    """Copies every method of a service descriptor onto a service node.

    Args:
      global_root: Root used to resolve fully qualified type names.
      package_root: Root used to resolve all other type names.
      service: The tree node to populate; must be of type SERVICE.
      proto_service: Service descriptor whose `method` entries are added.
    """
    assert service.type() == ProtoNode.Type.SERVICE
    target = cast(ProtoService, service)

    # Method kind keyed by (client streams, server streams).
    kinds = ProtoServiceMethod.Type
    kind_by_streaming = {
        (True, True): kinds.BIDIRECTIONAL_STREAMING,
        (True, False): kinds.CLIENT_STREAMING,
        (False, True): kinds.SERVER_STREAMING,
        (False, False): kinds.UNARY,
    }

    for rpc in proto_service.method:
        streaming = (bool(rpc.client_streaming), bool(rpc.server_streaming))
        method_type = kind_by_streaming[streaming]

        request_node = _find_or_create_node(global_root, package_root,
                                            rpc.input_type)
        response_node = _find_or_create_node(global_root, package_root,
                                             rpc.output_type)

        new_method = ProtoServiceMethod(rpc.name, method_type, request_node,
                                        response_node)
        target.add_method(new_method)
462
463
def _populate_fields(proto_file, global_root: ProtoNode,
                     package_root: ProtoNode) -> None:
    """Traverses a proto file, adding all message and enum fields to a tree.

    Every enum, message and service in the file is looked up by name below
    package_root, so the tree must already contain a node for each of them.

    Args:
      proto_file: File descriptor whose enums, messages and services are read.
      global_root: Root used to resolve fully qualified type names.
      package_root: Node under which the file's top-level entities live.
    """
    def populate_message(node: Optional[ProtoNode], message) -> None:
        """Recursively populates nested messages and enums."""
        # Every lookup is checked the same way, so that a missing node fails
        # here rather than further down inside a helper.
        assert node is not None
        _add_message_fields(global_root, package_root, node, message)

        for proto_enum in message.enum_type:
            nested_enum_node = node.find(proto_enum.name)
            assert nested_enum_node is not None
            _add_enum_fields(nested_enum_node, proto_enum)
        for msg in message.nested_type:
            populate_message(node.find(msg.name), msg)

    # Iterate through the proto file, populating top-level objects.
    for proto_enum in proto_file.enum_type:
        enum_node = package_root.find(proto_enum.name)
        assert enum_node is not None
        _add_enum_fields(enum_node, proto_enum)

    for message in proto_file.message_type:
        populate_message(package_root.find(message.name), message)

    for service in proto_file.service:
        service_node = package_root.find(service.name)
        assert service_node is not None
        _add_service_methods(global_root, package_root, service_node, service)
489
490
def _build_hierarchy(proto_file):
    """Builds the node skeleton of a proto file, without fields or methods.

    One package node is created per dot-separated segment of the file's
    package, and the file's enums, messages (with everything nested in them)
    and services are attached below the innermost package node.

    Note that splitting an empty package string still yields one (empty)
    segment, so an empty-named package node is created in that case too.

    Returns:
      A (global root, package root) tuple, where the package root is the
      innermost package node.
    """
    def message_subtree(descriptor):
        """Creates a message node along with all of its nested nodes."""
        message_node = ProtoMessage(descriptor.name)
        for nested_enum in descriptor.enum_type:
            message_node.add_child(ProtoEnum(nested_enum.name))
        for nested_message in descriptor.nested_type:
            message_node.add_child(message_subtree(nested_message))
        return message_node

    global_root = ProtoPackage('')

    # Descend one package node per segment of the package name.
    innermost = global_root
    for segment in proto_file.package.split('.'):
        segment_node = ProtoPackage(segment)
        innermost.add_child(segment_node)
        innermost = segment_node

    for enum_descriptor in proto_file.enum_type:
        innermost.add_child(ProtoEnum(enum_descriptor.name))

    for message_descriptor in proto_file.message_type:
        innermost.add_child(message_subtree(message_descriptor))

    for service_descriptor in proto_file.service:
        innermost.add_child(ProtoService(service_descriptor.name))

    return global_root, innermost
521
522
def build_node_tree(file_descriptor_proto) -> Tuple[ProtoNode, ProtoNode]:
    """Builds the complete proto node tree for a file descriptor.

    The node hierarchy is created first, after which fields, enum values and
    service methods are filled in.

    Returns:
      A (global root, package root) tuple: the root node of the entire proto
      package tree and the node representing the file's package.
    """
    roots = _build_hierarchy(file_descriptor_proto)
    _populate_fields(file_descriptor_proto, *roots)
    return roots
532