• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# python3
2# Copyright 2020 The ChromiumOS Authors
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5"""Proto-related helper functions."""
6
import importlib
import os

from typing import Any, Dict, List, NamedTuple, Optional, Set

from google.protobuf import message as pb_message
from google.protobuf import symbol_database
from google.protobuf.descriptor import FieldDescriptor

from chromiumos.config.public_replication import public_replication_pb2
17
18
def _build_type_map():
  """Return a dict from each FieldDescriptor.TYPE_* value to a readable name.

  The readable name is the constant's name with the "TYPE_" prefix removed,
  lowercased; e.g. FieldDescriptor.TYPE_INT32 maps to "int32".
  """
  prefix = "TYPE_"
  return {
      getattr(FieldDescriptor, attr_name): attr_name[len(prefix):].lower()
      for attr_name in dir(FieldDescriptor)
      if attr_name.startswith(prefix)
  }


# Lookup table from FieldDescriptor.TYPE_* values to lowercase type names,
# computed once at import time.
type_strings = _build_type_map()
29
30
class FieldInfo(NamedTuple):
  """Immutable record describing a single field of a protobuf message.

  Attributes:
    name: The field's name as declared in its containing message.
    typeid: The FieldDescriptor.TYPE_* constant identifying the field's type.
    typename: For message-typed fields, the name of the message type; for
      value fields, a lowercase string form of the field's type.
    repeated: Whether the field is declared as repeated.
  """
  name: str  # declared field name
  typeid: int  # FieldDescriptor.TYPE_* value
  typename: str  # message type name, or stringified scalar type
  repeated: bool  # True when the field is repeated
45
46
def create_symbol_db() -> symbol_database.SymbolDatabase:
  """Load any generated messages from python/ and return the symbol database.

  Messages auto-register when their generated modules are imported, so this
  recursively imports every *_pb2.py module under python/chromiumos and then
  returns the default SymbolDatabase instance.

  Returns:
    symbol_database.Default()
  """

  def __import_modules(dirname: str, paths: List[str]):
    """Recursively import every module ending in _pb2.py under dirname.

    Args:
      dirname: Directory to scan.
      paths: Package path components corresponding to dirname, used to build
        the dotted module name, e.g. ["chromiumos", "config"].
    """
    # Sort so the import order does not depend on filesystem listing order.
    for path in sorted(os.listdir(dirname)):
      full_path = os.path.join(dirname, path)

      # Do not follow symlinked directories; they could point back up the
      # tree and cause unbounded recursion.
      if os.path.isdir(full_path) and not os.path.islink(full_path):
        __import_modules(full_path, paths + [path])
      elif os.path.isfile(full_path) and path.endswith("_pb2.py"):
        module_name = path[:-len(".py")]
        importlib.import_module('%s.%s' % ('.'.join(paths), module_name))

  __import_modules(
      os.path.join(os.path.dirname(__file__), "../../python/chromiumos"),
      ["chromiumos"],
  )
  return symbol_database.Default()
81
82
def resolve_field_path(
    message: pb_message.Message,
    path: str,
) -> List[FieldInfo]:
  """Resolve a dotted field path into per-field type information.

  A field path looks like foo.bar.baz: each ".name" component names a field
  of the message type reached by the components before it. Resolution starts
  at the root message's descriptor and walks one field at a time.

  Args:
    message: A protobuf Message whose type is the root of the path.
    path: A dotted field path.

  Returns:
    A list with one entry per dotted component. An entry is a FieldInfo for
    each component that was resolved. Resolution stops at the first unknown
    field name, or right after a non-message field. Every component from that
    point on gets None.
  """
  components = path.split('.')
  resolved = []

  scope = message.DESCRIPTOR
  for component in components:
    field_desc = scope.fields_by_name.get(component)
    if not field_desc:
      break

    scalar_name = type_strings[field_desc.type]
    if field_desc.message_type:
      typename = field_desc.message_type.name
    else:
      typename = scalar_name

    resolved.append(
        FieldInfo(
            name=component,
            typeid=field_desc.type,
            typename=typename,
            repeated=(field_desc.label == field_desc.LABEL_REPEATED),
        ))

    # Only message fields have sub-fields to descend into.
    if field_desc.type != field_desc.TYPE_MESSAGE:
      break
    scope = field_desc.message_type

  # Pad unresolved trailing components with None.
  return resolved + [None] * (len(components) - len(resolved))
130
131
def get_all_fields(message: pb_message.Message) -> List[Any]:
  """Return the value of every field in message, in descriptor order.

  Only values are returned, not names. For instance, for a Timestamp message

  {
    "seconds": 1,
    "nanos": 2
  }

  this returns [1, 2].
  """
  return [
      getattr(message, descriptor.name)
      for descriptor in message.DESCRIPTOR.fields
  ]
149
150
def get_dep_graph(message: pb_message.Message) -> Dict[str, List[str]]:
  """Compute the dependency graph of a message type.

  The graph is a sparse adjacency representation. It is a dict keyed by the
  full name of every message type reachable from `message` (including its own
  type). Each value is the sorted list of full names of the message types
  that the key's fields directly reference.

  Each message type is expanded only once. This keeps shared dependencies
  cheap and makes recursive (self-referencing) message types terminate. For
  such types the returned graph contains cycles.

  Args:
    message: A protobuf message whose DESCRIPTOR roots the graph.

  Returns:
    Dict from message full name to a sorted list of dependency full names.
  """
  # Names of message types whose expansion has started. Marking a type before
  # descending into its fields is what breaks cycles.
  expanded = set()

  def helper(dep_graph, descriptor):
    name = descriptor.full_name
    if name in expanded:
      return
    expanded.add(name)

    deps = set()
    for field in descriptor.fields:
      if field.type == field.TYPE_MESSAGE:
        deps.add(field.message_type.full_name)
        helper(dep_graph, field.message_type)
    # Assigned after recursing, so keys are inserted in completion order.
    dep_graph[name] = sorted(deps)

  graph = {}
  helper(graph, message.DESCRIPTOR)
  return graph
170
171
def get_dep_order(message: pb_message.Message) -> List[str]:
  """Compute dependency order of a protobuf type and its dependencies.

  This is a post-order depth-first traversal of get_dep_graph(message). Each
  message type is listed once, after the types it depends on, except where a
  cycle makes that impossible. The root message type comes last.
  """

  def dfs(graph, callback, node, seen):
    # Mark before descending so cycles terminate and shared dependencies are
    # only walked once.
    if node in seen:
      return
    seen.add(node)
    for child in graph.get(node, ()):
      dfs(graph, callback, child, seen)
    callback(node)

  deps = []
  dfs(get_dep_graph(message), deps.append, message.DESCRIPTOR.full_name,
      set())
  return deps
193
194
def apply_public_replication(src: pb_message.Message, dst: pb_message.Message):
  """Walk src and merge fields into dst wherever a PublicReplication is found.

  The complete semantics of PublicReplication are described in the comment on
  the PublicReplication proto.

  Args:
    src: Source message.
    dst: Destination message to be merged into.
  """
  # The top-level call has no ancestors, so nothing has been visited yet.
  no_ancestors = set()
  __apply_public_replication_internal(src, dst, visited_messages=no_ancestors)
207
208
def __apply_public_replication_internal(
    src: pb_message.Message,
    dst: pb_message.Message,
    visited_messages: Optional[Set[str]] = None):
  """Private function to do most of the work of apply_public_replication.

  Allows bookkeeping information to be passed between recursive calls, along
  with the src and dst messages.

  Args:
    src: Source message.
    dst: Destination message to be merged into. Must be the same message type
      as src.
    visited_messages: Set of the full message names that have been visited
      higher in the call stack. Used to prevent infinite recursion if a message
      references itself. Treated as empty when omitted. It is never mutated.

  Raises:
    ValueError: If src and dst are not the same message type.
  """
  if visited_messages is None:
    visited_messages = set()

  if src.DESCRIPTOR.full_name != dst.DESCRIPTOR.full_name:
    raise ValueError(
        'src and dst must be the same message type. Got '
        f'{src.DESCRIPTOR.full_name} and {dst.DESCRIPTOR.full_name}')

  # If a message with the same type as src has been visited higher in the call
  # stack, stop here. Otherwise build a new set that includes src's name.
  # Adding src's name to visited_messages in place would be wrong: a message
  # of the same type may be visited after this function exits, e.g. the next
  # element in a list of messages with the same type.
  if src.DESCRIPTOR.full_name in visited_messages:
    return

  visited_messages = visited_messages.union([src.DESCRIPTOR.full_name])

  # First, see if any of the fields on src are a PublicReplication message. If
  # one is found, apply the field mask within it.
  public_replication_name = (
      public_replication_pb2.PublicReplication.DESCRIPTOR.full_name)
  for field_descriptor in src.DESCRIPTOR.fields:
    message_descriptor = field_descriptor.message_type
    if (message_descriptor and
        message_descriptor.full_name == public_replication_name):
      public_fields = getattr(src, field_descriptor.name).public_fields
      public_fields.MergeMessage(src, dst)

  # Iterate the message-typed fields of src and recurse into each one.
  for field_descriptor in src.DESCRIPTOR.fields:
    if field_descriptor.type != field_descriptor.TYPE_MESSAGE:
      continue

    field_name = field_descriptor.name
    if field_descriptor.label == field_descriptor.LABEL_REPEATED:
      dst_field = getattr(dst, field_name)
      # Map fields are considered repeated messages, but do not have an 'add'
      # method. Skip this case. It wasn't clear if there was a better way to
      # detect a map field via the descriptor.
      if not hasattr(dst_field, 'add'):
        continue

      # For each message in src, create a new message in dst and recurse.
      for next_src in getattr(src, field_name):
        next_dst = dst_field.add()
        __apply_public_replication_internal(next_src, next_dst,
                                            visited_messages)
        # If the newly added message doesn't have any fields set, remove it to
        # avoid creating many empty messages on dst.
        if not next_dst.ByteSize():
          dst_field.pop()
    else:
      # For non-repeated fields, get the field in src and dst and recurse.
      next_src = getattr(src, field_name)
      next_dst = getattr(dst, field_name)
      __apply_public_replication_internal(next_src, next_dst, visited_messages)
      # If the dst field doesn't have any fields set, clear it to avoid
      # creating many empty messages on dst. Skip this when the field was also
      # empty on src, since the presence of message fields may be meaningful in
      # some cases. Skipping also avoids clobbering fields nested within a
      # message in a oneof where the currently-set oneof is a different field.
      if not next_dst.ByteSize() and next_src.ByteSize():
        dst.ClearField(field_name)