# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tools for compiling and importing Python protos on the fly."""

from collections.abc import Mapping
import hashlib
import importlib.util
import logging
import os
from pathlib import Path
import subprocess
import shlex
import tempfile
from types import ModuleType
from typing import (
    Dict,
    Generator,
    Generic,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

try:
    # pylint: disable=wrong-import-position
    import black

    black_mode: Optional[black.Mode] = black.Mode(string_normalization=False)

    # pylint: enable=wrong-import-position
except ImportError:
    black = None  # type: ignore
    black_mode = None

_LOG = logging.getLogger(__name__)

PathOrStr = Union[Path, str]


def compile_protos(
    output_dir: PathOrStr,
    proto_files: Iterable[PathOrStr],
    includes: Iterable[PathOrStr] = (),
) -> None:
    """Compiles proto files for Python by invoking the protobuf compiler.

    Proto files not covered by one of the provided include paths will have
    their directory added as an include path.

    Args:
      output_dir: directory passed to protoc's --python_out
      proto_files: paths to the .proto files to compile
      includes: include (-I) paths passed to protoc

    Raises:
      subprocess.CalledProcessError: if protoc exits with a nonzero status
    """
    proto_paths: List[Path] = [Path(f).resolve() for f in proto_files]

    # A dict is used as an insertion-ordered set: protoc searches -I paths in
    # the order given, so that order must be deterministic (a set's is not).
    include_paths: Dict[Path, None] = dict.fromkeys(
        Path(d).resolve() for d in includes
    )

    for path in proto_paths:
        if not any(include in path.parents for include in include_paths):
            include_paths[path.parent] = None

    cmd: Tuple[PathOrStr, ...] = (
        'protoc',
        '--experimental_allow_proto3_optional',
        '--python_out',
        os.path.abspath(output_dir),
        *(f'-I{d}' for d in include_paths),
        *proto_paths,
    )
    cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)

    _LOG.debug('%s', cmd_str)
    process = subprocess.run(cmd, capture_output=True, check=False)

    if process.returncode:
        _LOG.error(
            'protoc invocation failed!\n%s\n%s',
            cmd_str,
            # Never let undecodable compiler output mask the real error.
            process.stderr.decode(errors='replace'),
        )
        process.check_returncode()


def _import_module(name: str, path: str) -> ModuleType:
    """Loads and executes the Python file at path as a module named name.

    The module is executed but not registered in sys.modules.

    Raises:
      ImportError: if no module spec or loader can be created for path
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # Raise rather than assert so the check survives `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f'Unable to load module {name!r} from {path!r}')

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def import_modules(directory: PathOrStr) -> Iterator[ModuleType]:
    """Imports modules in a directory and yields them.

    Every .py file under directory (recursively) is imported. Module names are
    dotted paths relative to directory's parent, e.g. <dir>/sub/foo.py is
    named '<dir>.sub.foo'. Modules are yielded in sorted, top-down order.
    """
    # Normalize first: os.path.dirname() of a path with a trailing separator
    # (or of '.') is not the real parent, which led to names like '..foo'.
    root = os.path.abspath(directory)
    parent = os.path.dirname(root)

    for dirpath, dirnames, files in os.walk(root):
        # Sorting dirnames in place makes os.walk descend deterministically.
        dirnames.sort()
        package = '.'.join(os.path.relpath(dirpath, parent).split(os.sep))

        for file in sorted(files):
            name, ext = os.path.splitext(file)

            if ext == '.py':
                yield _import_module(
                    f'{package}.{name}',
                    os.path.join(dirpath, file),
                )


def compile_and_import(
    proto_files: Iterable[PathOrStr],
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
) -> Generator[ModuleType, None, None]:
    """Compiles protos and imports their modules; yields the proto modules.

    Args:
      proto_files: paths to .proto files to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted

    Yields:
      the generated protobuf Python modules
    """
    if output_dir:
        compile_protos(output_dir, proto_files, includes)
        yield from import_modules(output_dir)
    else:
        # The temporary directory exists until this generator finishes or is
        # closed; each module is already executed when it is yielded.
        with tempfile.TemporaryDirectory(prefix='compiled_protos_') as tempdir:
            compile_protos(tempdir, proto_files, includes)
            yield from import_modules(tempdir)


def compile_and_import_file(
    proto_file: PathOrStr,
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
):
    """Compiles and imports the module for a single .proto file.

    Returns:
      the first module yielded by compile_and_import for proto_file
    """
    modules = compile_and_import([proto_file], includes, output_dir)
    try:
        return next(modules)
    finally:
        # Closing the generator exits its TemporaryDirectory context right
        # away instead of leaving the cleanup to garbage collection.
        modules.close()


def compile_and_import_strings(
    contents: Iterable[str],
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
) -> Iterator[ModuleType]:
    """Compiles protos in one or more strings.

    Args:
      contents: .proto source text; a single string is also accepted
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory
          is used if omitted

    Yields:
      the generated protobuf Python modules
    """
    if isinstance(contents, str):
        contents = [contents]

    with tempfile.TemporaryDirectory(prefix='proto_sources_') as path:
        protos: List[Path] = []
        seen: Set[Path] = set()

        for proto in contents:
            # Name each file after a digest of its contents so the same
            # contents map to the same file name. The protobuf package
            # complains if it sees the same contents in files with different
            # names. Unlike hash(), the digest is never negative and is stable
            # across processes.
            digest = hashlib.sha256(proto.encode('utf-8')).hexdigest()[:16]
            proto_path = Path(path, f'protobuf_{digest}.proto')

            # Identical contents map to the same file; compile it only once.
            if proto_path in seen:
                continue

            seen.add(proto_path)
            proto_path.write_text(proto, encoding='utf-8')
            protos.append(proto_path)

        yield from compile_and_import(protos, includes, output_dir)


T = TypeVar('T')


class _NestedPackage(Generic[T]):
    """Facilitates navigating protobuf packages as attributes."""

    def __init__(self, package: str):
        self._packages: Dict[str, _NestedPackage[T]] = {}
        self._items: List[T] = []
        self._package = package

    def _add_package(self, subpackage: str, package: '_NestedPackage') -> None:
        self._packages[subpackage] = package

    def _add_item(self, item) -> None:
        if item not in self._items:  # Don't store the same item multiple times.
            self._items.append(item)

    def __getattr__(self, attr: str):
        """Look up subpackages or package members."""
        # __getattr__ only runs after normal lookup fails. Instances created
        # without __init__ (e.g. by copy.copy or pickle) have no _packages
        # yet; reading self._packages would re-enter __getattr__ forever.
        state = vars(self)
        if '_packages' not in state or '_items' not in state:
            raise AttributeError(attr)

        if attr in self._packages:
            return self._packages[attr]

        for item in self._items:
            if hasattr(item, attr):
                return getattr(item, attr)

        raise AttributeError(
            f'Proto package "{self._package}" does not contain "{attr}"'
        )

    def __getitem__(self, subpackage: str) -> '_NestedPackage[T]':
        """Support accessing nested packages by name."""
        result = self

        for package in subpackage.split('.'):
            result = result._packages[package]

        return result

    def __dir__(self) -> List[str]:
        """List subpackages and members of modules as attributes."""
        attributes = list(self._packages)

        for item in self._items:
            for attr, value in vars(item).items():
                # Exclude private variables and modules from dir().
                if not attr.startswith('_') and not isinstance(
                    value, ModuleType
                ):
                    attributes.append(attr)

        # Items may share attribute names (e.g. each module in a Library has a
        # DESCRIPTOR); list every name once, keeping first-seen order.
        return list(dict.fromkeys(attributes))

    def __iter__(self) -> Iterator['_NestedPackage[T]']:
        """Iterate over nested packages."""
        return iter(self._packages.values())

    def __repr__(self) -> str:
        msg = [f'ProtoPackage({self._package!r}']

        # vars(self) only holds this object's private state, so gather the
        # public members through dir(), which calls __dir__ above.
        public_members = [
            attr
            for attr in dir(self)
            if attr not in self._packages and not attr.startswith('_')
        ]
        if public_members:
            msg.append(f'members={str(public_members)}')

        if self._packages:
            msg.append(f'subpackages={str(list(self._packages))}')

        return ', '.join(msg) + ')'

    def __str__(self) -> str:
        return self._package


class Packages(NamedTuple):
    """Items in a protobuf package structure; returned from as_package."""

    items_by_package: Dict[str, List]
    packages: _NestedPackage


def as_packages(
    items: Iterable[Tuple[str, T]], packages: Optional[Packages] = None
) -> Packages:
    """Places items in a proto-style package structure navigable by attributes.

    Args:
      items: (package, item) tuples to insert into the package structure
      packages: if provided, update this Packages instead of creating a new one

    Returns:
      the updated (or newly created) Packages
    """
    if packages is None:
        packages = Packages({}, _NestedPackage(''))

    for package, item in items:
        packages.items_by_package.setdefault(package, []).append(item)

        entry = packages.packages
        # NOTE(review): an empty package name splits to [''], so such items
        # land in a subpackage named '' rather than on the root.
        subpackages = package.split('.')

        # pylint: disable=protected-access
        for i, subpackage in enumerate(subpackages, 1):
            if subpackage not in entry._packages:
                entry._add_package(
                    subpackage, _NestedPackage('.'.join(subpackages[:i]))
                )

            entry = entry._packages[subpackage]

        entry._add_item(item)
        # pylint: enable=protected-access

    return packages


PathOrModule = Union[str, Path, ModuleType]


class Library:
    """A collection of protocol buffer modules sorted by package.

    In Python, each .proto file is compiled into a Python module. The Library
    class makes it simple to navigate a collection of Python modules
    corresponding to .proto files, without relying on the location of these
    compiled modules.

    Proto messages and other types can be directly accessed by their protocol
    buffer package name. For example, the foo.bar.Baz message can be accessed
    in a Library called `protos` as:

      protos.packages.foo.bar.Baz

    A Library also provides the modules_by_package dictionary, for looking up
    the list of modules in a particular package, and the modules() generator
    for iterating over all modules.
    """

    @classmethod
    def from_paths(cls, protos: Iterable[PathOrModule]) -> 'Library':
        """Creates a Library from paths to proto files or proto modules.

        str and Path entries are compiled with compile_and_import; any other
        entry is used as an already-imported proto module.
        """
        paths: List[PathOrStr] = []
        modules: List[ModuleType] = []

        for proto in protos:
            if isinstance(proto, (Path, str)):
                paths.append(proto)
            else:
                modules.append(proto)

        if paths:
            modules += compile_and_import(paths)

        # Use cls (as from_strings does) so subclasses construct themselves.
        return cls(modules)

    @classmethod
    def from_strings(
        cls,
        contents: Iterable[str],
        includes: Iterable[PathOrStr] = (),
        output_dir: Optional[PathOrStr] = None,
    ) -> 'Library':
        """Creates a proto library from protos in the provided strings."""
        return cls(compile_and_import_strings(contents, includes, output_dir))

    def __init__(self, modules: Iterable[ModuleType]):
        """Constructs a Library from an iterable of modules.

        A Library can be constructed with modules dynamically compiled by
        compile_and_import. For example:

          protos = Library(compile_and_import(list_of_proto_files))
        """
        self.modules_by_package, self.packages = as_packages(
            (m.DESCRIPTOR.package, m)  # type: ignore[attr-defined]
            for m in modules
        )

    def modules(self) -> Iterable:
        """Iterates over all protobuf modules in this library."""
        for module_list in self.modules_by_package.values():
            yield from module_list

    def messages(self) -> Iterable:
        """Iterates over all protobuf messages in this library."""
        for module in self.modules():
            yield from _nested_messages(
                module, module.DESCRIPTOR.message_types_by_name
            )


def _nested_messages(scope, message_names: Iterable[str]) -> Iterator:
    """Yields the named messages in scope, each followed by its nested types."""
    for name in message_names:
        msg = getattr(scope, name)
        yield msg
        yield from _nested_messages(msg, msg.DESCRIPTOR.nested_types_by_name)


def _repr_char(char: int) -> str:
    r"""Returns an ASCII char or the \x code for non-printable values.

    Single quotes and backslashes are escaped so the result is valid inside a
    single-quoted bytes literal.
    """
    if ord(' ') <= char <= ord('~'):
        text = chr(char)
        # An unescaped backslash would turn b'\\' into the invalid b'\'.
        return '\\' + text if text in ("'", '\\') else text

    return f'\\x{char:02X}'


def bytes_repr(value: bytes) -> str:
    """Prints bytes as mixed ASCII only if at least half are printable."""
    ascii_char_count = sum(ord(' ') <= c <= ord('~') for c in value)
    if ascii_char_count >= len(value) / 2:
        contents = ''.join(_repr_char(c) for c in value)
    else:
        contents = ''.join(f'\\x{c:02X}' for c in value)

    return f"b'{contents}'"


def _field_repr(field, value) -> str:
    """Formats one field value: enums by name, messages and bytes specially."""
    if field.type == field.TYPE_ENUM:
        try:
            enum = field.enum_type.values_by_number[value]
            return f'{field.enum_type.full_name}.{enum.name}'
        except KeyError:
            return repr(value)

    if field.type == field.TYPE_MESSAGE:
        # Only the outermost proto_repr call should run black; wrapping nested
        # messages wastes work and their trailing commas would force the outer
        # formatting to stay exploded.
        return proto_repr(value, wrap=False)

    if field.type == field.TYPE_BYTES:
        return bytes_repr(value)

    return repr(value)


def _proto_repr(message) -> Iterator[str]:
    """Yields 'name=value' strings for the fields that are set in message."""
    for field in message.DESCRIPTOR.fields:
        value = getattr(message, field.name)

        # Skip fields that are not present.
        try:
            if not message.HasField(field.name):
                continue
        except ValueError:
            # Skip default-valued fields that don't support HasField.
            if (
                field.label != field.LABEL_REPEATED
                and value == field.default_value
            ):
                continue

        if field.label == field.LABEL_REPEATED:
            if not value:
                continue

            if isinstance(value, Mapping):
                key_desc, value_desc = field.message_type.fields
                values = ', '.join(
                    f'{_field_repr(key_desc, k)}: {_field_repr(value_desc, v)}'
                    for k, v in value.items()
                )
                yield f'{field.name}={{{values}}}'
            else:
                values = ', '.join(_field_repr(field, v) for v in value)
                yield f'{field.name}=[{values}]'
        else:
            yield f'{field.name}={_field_repr(field, value)}'


def proto_repr(message, *, wrap: bool = True) -> str:
    """Creates a repr-like string for a protobuf.

    In an interactive console that imports proto objects into the namespace, the
    output of proto_repr() can be used as Python source to create a proto
    object.

    Args:
      message: The protobuf message to format
      wrap: If true and black is available, the output is wrapped according to
          PEP8 using black.
    """
    raw = f'{message.DESCRIPTOR.full_name}({", ".join(_proto_repr(message))})'

    if wrap and black is not None and black_mode is not None:
        return black.format_str(raw, mode=black_mode).strip()

    return raw