# python3
# Copyright 2020 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Proto-related helper functions."""

import importlib
import os

from typing import Any, Dict, List, NamedTuple, Optional, Set

from google.protobuf import message as pb_message
from google.protobuf import symbol_database
from google.protobuf.descriptor import FieldDescriptor

from chromiumos.config.public_replication import public_replication_pb2


def _build_type_map() -> Dict[int, str]:
  """Build a map from the TYPE_* constants in protobuf to string names.

  Returns:
    Dict keyed by the value of each FieldDescriptor.TYPE_* attribute. The
    value is the attribute name with the "TYPE_" prefix stripped and
    lower-cased.
  """
  prefix = "TYPE_"
  return {
      getattr(FieldDescriptor, attr): attr[len(prefix):].lower()
      for attr in dir(FieldDescriptor)
      if attr.startswith(prefix)
  }


# Lookup table from FieldDescriptor type id to a printable type name.
type_strings = _build_type_map()


class FieldInfo(NamedTuple):
  """Named tuple to represent information about a proto field.

  Fields:
    name: Name of the field in the message
    typeid: FieldDescriptor type for the field
    typename: For compound fields, the name of the message type.
              For value fields, a string-ified type name
    repeated: True if the field is repeated
  """
  name: str
  typeid: int
  typename: str
  repeated: bool


def create_symbol_db() -> symbol_database.SymbolDatabase:
  """Load any generated messages from python/ and return symbol database.

  Messages auto-register when imported, so we just recursively import any
  generated protobuffer modules and return the default instance of
  SymbolDatabase.

  Returns:
    symbol_database.Default()
  """

  def __import_modules(dirname: str, paths: List[str]):
    """Recurse through a directory tree and automatically load protobufs.

    This starts with dirname and recursively descends looking for python
    files ending with _pb2.py and imports them by dotted module name.

    Args:
      dirname: Directory name to process
      paths: list of parent package names leading to dirname, used to build
        the dotted module name of each file that is imported
    """
    for path in os.listdir(dirname):
      full_path = os.path.join(dirname, path)

      # Symlinked directories are not followed.
      if os.path.isdir(full_path) and not os.path.islink(full_path):
        __import_modules(full_path, paths + [path])
      elif os.path.isfile(full_path) and path.endswith("_pb2.py"):
        # path[:-3] drops the trailing ".py" to leave the module name.
        importlib.import_module(f"{'.'.join(paths)}.{path[:-3]}")

  __import_modules(
      os.path.join(os.path.dirname(__file__), "../../python/chromiumos"),
      ["chromiumos"],
  )
  return symbol_database.Default()


def resolve_field_path(
    message: pb_message.Message,
    path: str,
) -> List[Optional[FieldInfo]]:
  """Resolve a dotted field path into specific information about the fields.

  A field path is of the format foo.bar.baz where the dotted notation .field
  indicates a particular field of the preceding message type.

  We resolve this by iteratively looking up each field starting with a root
  message type, and populating information about it.

  Args:
    message: a protobuffer Message type to root the path in
    path: a dotted field path

  Returns:
    A list with one entry per component of path. Each entry is a FieldInfo,
    or None for a component that could not be resolved. Resolution stops at
    the first unknown field and after the first non-message field, so every
    entry following such a component is None.
  """

  fields = path.split('.')
  infos: List[Optional[FieldInfo]] = [None] * len(fields)

  current = message.DESCRIPTOR
  for idx, field in enumerate(fields):
    descriptor = current.fields_by_name.get(field)
    if not descriptor:
      break

    # Compound fields report the message name, value fields the scalar name.
    if descriptor.message_type:
      typename = descriptor.message_type.name
    else:
      typename = type_strings[descriptor.type]

    infos[idx] = FieldInfo(
        name=field,
        typeid=descriptor.type,
        typename=typename,
        repeated=(descriptor.label == descriptor.LABEL_REPEATED),
    )

    # If field isn't a message type then there is nothing to descend into.
    if descriptor.type != descriptor.TYPE_MESSAGE:
      break

    current = descriptor.message_type

  return infos


def get_all_fields(message: pb_message.Message) -> List[Any]:
  """Returns the value of each field in message.

  Note that the value, not the name, is returned. For example, given a
  Timestamp message:

  {
    "seconds": 1,
    "nanos": 2
  }

  the result is [1, 2].

  Args:
    message: Message to read the field values from.

  Returns:
    The field values, in the order of message.DESCRIPTOR.fields.
  """
  return [
      getattr(message, field_descriptor.name)
      for field_descriptor in message.DESCRIPTOR.fields
  ]


def get_dep_graph(message: pb_message.Message) -> Dict[str, List[str]]:
  """Compute the dep graph of a message.

  This is a sparse representation of the dependency graph as a dict where
  the key is the full name of a message, and the value is the sorted list of
  full names of the message types used by that message's fields.

  Each message type is expanded only once, so message types that reference
  themselves (directly or through other messages) terminate instead of
  recursing forever.

  Args:
    message: Message (or message class) whose DESCRIPTOR is the graph root.

  Returns:
    Dict from message full name to the sorted list of its direct
    dependencies.
  """

  def helper(dep_graph, descriptor, visited):
    # Marking on entry (rather than on completion) is what breaks cycles.
    if descriptor.full_name in visited:
      return
    visited.add(descriptor.full_name)

    deps = set()
    for field in descriptor.fields:
      if field.type == field.TYPE_MESSAGE:
        deps.add(field.message_type.full_name)
        helper(dep_graph, field.message_type, visited)
    # Inserted after the children so that, for acyclic inputs, dependencies
    # appear in the dict before the messages that use them.
    dep_graph[descriptor.full_name] = sorted(deps)

  graph = {}
  helper(graph, message.DESCRIPTOR, set())
  return graph


def get_dep_order(message: pb_message.Message) -> List[str]:
  """Compute dependency order of protobuf type and its dependencies.

  This is a list from a postorder depth-first traversal of the dependency
  graph returned by get_dep_graph: every message type is listed after the
  types it depends on, and the root message comes last. Each type appears
  exactly once. For types that are part of a reference cycle the order
  within the cycle follows the traversal order.

  Args:
    message: Message (or message class) to compute the order for.

  Returns:
    List of message full names in dependency order.
  """

  def dfs(graph, callback, node, seen):
    # Marking on entry guards against cycles and repeated sub-traversals.
    if node in seen:
      return
    seen.add(node)

    for child in graph.get(node, ()):
      dfs(graph, callback, child, seen)

    callback(node)

  deps = []
  dfs(get_dep_graph(message), deps.append, message.DESCRIPTOR.full_name,
      set())
  return deps


def apply_public_replication(src: pb_message.Message, dst: pb_message.Message):
  """Traverses src and merges fields to dst when a PublicReplication message is
  found.

  See the comment on the PublicReplication proto for a complete description of
  the semantics of PublicReplication.

  Args:
    src: Source message.
    dst: Destination message to be merged into.

  Raises:
    ValueError: If src and dst are not the same message type.
  """
  __apply_public_replication_internal(src, dst, visited_messages=set())


def __apply_public_replication_internal(
    src: pb_message.Message,
    dst: pb_message.Message,
    visited_messages: Optional[Set[str]] = None):
  """Private function to do most of the work of apply_public_replication.

  Allows bookkeeping information to be passed between recursive calls, along
  with the src and dst messages.

  Args:
    src: Source message.
    dst: Destination message to be merged into.
    visited_messages: Set of the full message names that have been visited
      higher in the call stack. Used to prevent infinite recursion if a
      message references itself. None is treated as the empty set. The set
      passed in is never mutated.

  Raises:
    ValueError: If src and dst are not the same message type.
  """
  src_name = src.DESCRIPTOR.full_name
  if src_name != dst.DESCRIPTOR.full_name:
    raise ValueError(
        'src and dst must be the same message type. Got '
        f'{src.DESCRIPTOR.full_name} and {dst.DESCRIPTOR.full_name}')

  if visited_messages is None:
    visited_messages = set()

  # If a message with the same type as src has been visited higher in the call
  # stack, stop here. Then, create a new set with src's name in it. Note that
  # it doesn't work to add src's name to visited_messages in place, because a
  # message with the same type as src might be visited after this function
  # exits, e.g. in a list of messages with the same type.
  if src_name in visited_messages:
    return

  visited_messages = visited_messages.union([src_name])

  # First, see if any of the fields on src are a PublicReplication message. If
  # one is found, apply the field mask within it.
  public_replication_name = (
      public_replication_pb2.PublicReplication.DESCRIPTOR.full_name)
  for field_descriptor in src.DESCRIPTOR.fields:
    message_descriptor = field_descriptor.message_type
    if (message_descriptor and
        message_descriptor.full_name == public_replication_name):
      public_fields = getattr(src, field_descriptor.name).public_fields
      public_fields.MergeMessage(src, dst)

  # Iterate the fields of src and call __apply_public_replication_internal
  # recursively.
  for field_descriptor in src.DESCRIPTOR.fields:
    if field_descriptor.type != field_descriptor.TYPE_MESSAGE:
      continue

    if field_descriptor.label == field_descriptor.LABEL_REPEATED:
      dst_field = getattr(dst, field_descriptor.name)
      # map fields are considered repeated messages, but do not have an 'add'
      # method. Skip this case. It wasn't clear if there was a better way to
      # detect a map field via the descriptor.
      if not hasattr(dst_field, 'add'):
        continue

      # For repeated fields, for each message in src, create a new message in
      # dst and call __apply_public_replication_internal.
      for next_src in getattr(src, field_descriptor.name):
        next_dst = dst_field.add()
        __apply_public_replication_internal(next_src, next_dst,
                                            visited_messages)
        # If the newly added message doesn't have any fields set, remove it
        # to avoid creating many empty messages on dst.
        if not dst_field[-1].ByteSize():
          dst_field.pop()
    else:
      # For non-repeated fields, get the field in src and dst and call
      # __apply_public_replication_internal.
      next_src = getattr(src, field_descriptor.name)
      next_dst = getattr(dst, field_descriptor.name)
      __apply_public_replication_internal(next_src, next_dst, visited_messages)
      # If the newly added field doesn't have any fields set, remove it to
      # avoid creating many empty messages on dst, unless it was also empty on
      # src, since the presence of message fields may be meaningful in some
      # cases. Additionally, this avoids clobbering fields nested within a
      # message in a oneof where the currently-set oneof is a different field.
      if not next_dst.ByteSize() and next_src.ByteSize():
        dst.ClearField(field_descriptor.name)