# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tools for compiling and importing Python protos on the fly."""

from collections.abc import Mapping
import importlib.util
import logging
import os
from pathlib import Path
import subprocess
import shlex
import tempfile
from types import ModuleType
from typing import (Dict, Generic, Iterable, Iterator, List, NamedTuple,
                    Optional, Set, Tuple, TypeVar, Union)

_LOG = logging.getLogger(__name__)

PathOrStr = Union[Path, str]


def compile_protos(
        output_dir: PathOrStr,
        proto_files: Iterable[PathOrStr],
        includes: Iterable[PathOrStr] = ()) -> None:
    """Compiles proto files for Python by invoking the protobuf compiler.

    Proto files not covered by one of the provided include paths will have their
    directory added as an include path.

    Args:
      output_dir: directory passed to protoc's --python_out
      proto_files: .proto files to compile
      includes: directories passed to protoc as -I include paths

    Raises:
      subprocess.CalledProcessError: if protoc exits with a nonzero status
    """
    proto_paths: List[Path] = [Path(f).resolve() for f in proto_files]
    include_paths: Set[Path] = set(Path(d).resolve() for d in includes)

    for path in proto_paths:
        if not any(include in path.parents for include in include_paths):
            include_paths.add(path.parent)

    cmd: Tuple[PathOrStr, ...] = (
        'protoc',
        '--experimental_allow_proto3_optional',
        '--python_out',
        os.path.abspath(output_dir),
        # Sort the include paths: set iteration order follows Path hashes,
        # which derive from per-process randomized string hashes, so iterating
        # the raw set would hand protoc a different -I order on each run.
        *(f'-I{d}' for d in sorted(include_paths)),
        *proto_paths,
    )

    _LOG.debug('%s', ' '.join(shlex.quote(str(c)) for c in cmd))
    process = subprocess.run(cmd, capture_output=True)

    if process.returncode:
        # errors='replace' so undecodable compiler output cannot mask the
        # actual protoc failure with a UnicodeDecodeError.
        _LOG.error('protoc invocation failed!\n%s\n%s',
                   ' '.join(shlex.quote(str(c)) for c in cmd),
                   process.stderr.decode(errors='replace'))
        process.check_returncode()


def _import_module(name: str, path: str) -> ModuleType:
    """Executes the Python source file at path as a module named name.

    The module object is created and executed directly from its file spec; this
    function does not register it in sys.modules.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # type: ignore[union-attr]
    return module


def import_modules(directory: PathOrStr) -> Iterator:
    """Imports modules in a directory and yields them.

    Every .py file under directory (searched recursively) is imported. Module
    names are dotted paths relative to directory's parent, so the directory's
    own name is the first component: out/sub/foo_pb2.py is imported as
    out.sub.foo_pb2.
    """
    # Normalize first. With a trailing separator (or a bare '.'), the relative
    # path of the top directory would be '.', yielding names like '..foo_pb2'.
    root = os.path.abspath(directory)
    parent = os.path.dirname(root)

    for dirpath, _, files in os.walk(root):
        path_parts = os.path.relpath(dirpath, parent).split(os.sep)

        for file in files:
            name, ext = os.path.splitext(file)

            if ext == '.py':
                yield _import_module(f'{".".join(path_parts)}.{name}',
                                     os.path.join(dirpath, file))


def compile_and_import(proto_files: Iterable[PathOrStr],
                       includes: Iterable[PathOrStr] = (),
                       output_dir: Optional[PathOrStr] = None) -> Iterator:
    """Compiles protos and imports their modules; yields the proto modules.

    This is a generator: protoc is not run until the first module is
    requested. When no output_dir is given, the temporary directory holding
    the generated sources is deleted once the generator finishes or is closed.
    Modules that were already yielded remain usable.

    Args:
      proto_files: paths to .proto files to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted

    Yields:
      the generated protobuf Python modules
    """

    if output_dir:
        compile_protos(output_dir, proto_files, includes)
        yield from import_modules(output_dir)
    else:
        with tempfile.TemporaryDirectory(prefix='compiled_protos_') as tempdir:
            compile_protos(tempdir, proto_files, includes)
            yield from import_modules(tempdir)


def compile_and_import_file(proto_file: PathOrStr,
                            includes: Iterable[PathOrStr] = (),
                            output_dir: Optional[PathOrStr] = None):
    """Compiles and imports the module for a single .proto file.

    Returns the first module found in the output directory. NOTE(review): if
    output_dir already contains other generated modules, the first one found
    may not be the module generated for proto_file.
    """
    return next(compile_and_import([proto_file], includes, output_dir))


def compile_and_import_strings(contents: Iterable[str],
                               includes: Iterable[PathOrStr] = (),
                               output_dir: Optional[PathOrStr] = None
                               ) -> Iterator:
    """Compiles protos in one or more strings.

    Args:
      contents: .proto source text; a single str is treated as one file
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted

    Yields:
      the generated protobuf Python modules
    """

    if isinstance(contents, str):
        contents = [contents]

    with tempfile.TemporaryDirectory(prefix='proto_sources_') as path:
        protos = []

        for proto in contents:
            # Use a hash of the proto so the same contents map to the same file
            # name. The protobuf package complains if it sees the same contents
            # in files with different names. (str hashes are stable only
            # within one interpreter process.)
            protos.append(Path(path, f'protobuf_{hash(proto):x}.proto'))
            # protoc reads its input as UTF-8, so don't rely on the locale's
            # default encoding.
            protos[-1].write_text(proto, encoding='utf-8')

        yield from compile_and_import(protos, includes, output_dir)


T = TypeVar('T')


class _NestedPackage(Generic[T]):
    """Facilitates navigating protobuf packages as attributes.

    Each node holds its child packages (by name) and the items placed directly
    in that package. Attribute lookups check child packages first, then the
    attributes of each item in insertion order.
    """
    def __init__(self, package: str):
        self._packages: Dict[str, _NestedPackage[T]] = {}
        self._items: List[T] = []
        self._package = package  # Full dotted name of this package.

    def _add_package(self, subpackage: str,
                     package: '_NestedPackage[T]') -> None:
        self._packages[subpackage] = package

    def _add_item(self, item) -> None:
        if item not in self._items:  # Don't store the same item multiple times.
            self._items.append(item)

    def __getattr__(self, attr: str):
        """Look up subpackages or package members."""
        if attr in self._packages:
            return self._packages[attr]

        for item in self._items:
            if hasattr(item, attr):
                return getattr(item, attr)

        raise AttributeError(
            f'Proto package "{self._package}" does not contain "{attr}"')

    def __getitem__(self, subpackage: str) -> '_NestedPackage[T]':
        """Support accessing nested packages by dotted name.

        Raises:
          KeyError: if any component of subpackage does not exist
        """
        result = self

        for package in subpackage.split('.'):
            result = result._packages[package]

        return result

    def __dir__(self) -> List[str]:
        """List subpackages and members of modules as attributes."""
        attributes = list(self._packages)

        for item in self._items:
            # Items without a __dict__ contribute no members rather than making
            # dir() (and therefore repr()) raise TypeError.
            for attr, value in getattr(item, '__dict__', {}).items():
                # Exclude private variables and modules from dir().
                if not attr.startswith('_') and not isinstance(
                        value, ModuleType):
                    attributes.append(attr)

        return attributes

    def __iter__(self) -> Iterator['_NestedPackage[T]']:
        """Iterate over nested packages."""
        return iter(self._packages.values())

    def __repr__(self) -> str:
        msg = [f'ProtoPackage({self._package!r}']

        # dir() goes through __dir__, which lists the items' public members.
        # vars(self) would only see this object's own underscore-prefixed
        # attributes, so no members would ever be reported.
        public_members = [
            i for i in dir(self)
            if i not in self._packages and not i.startswith('_')
        ]
        if public_members:
            msg.append(f'members={str(public_members)}')

        if self._packages:
            msg.append(f'subpackages={str(list(self._packages))}')

        return ', '.join(msg) + ')'

    def __str__(self) -> str:
        return self._package


class Packages(NamedTuple):
    """Items in a protobuf package structure; returned from as_packages."""
    # Items grouped by full dotted package name.
    items_by_package: Dict[str, List]
    # Root of the attribute-navigable package tree.
    packages: _NestedPackage


def as_packages(items: Iterable[Tuple[str, T]],
                packages: Optional[Packages] = None) -> Packages:
    """Places items in a proto-style package structure navigable by attributes.

    Args:
      items: (package, item) tuples to insert into the package structure
      packages: if provided, update this Packages instead of creating a new one

    Returns:
      the updated (or newly created) Packages
    """
    if packages is None:
        packages = Packages({}, _NestedPackage(''))

    for package, item in items:
        packages.items_by_package.setdefault(package, []).append(item)

        entry = packages.packages
        # NOTE(review): an empty package name produces a single '' component,
        # so such items live under a subpackage named '' of the root.
        subpackages = package.split('.')

        # pylint: disable=protected-access
        for i, subpackage in enumerate(subpackages, 1):
            if subpackage not in entry._packages:
                entry._add_package(subpackage,
                                   _NestedPackage('.'.join(subpackages[:i])))

            entry = entry._packages[subpackage]

        entry._add_item(item)
        # pylint: enable=protected-access

    return packages


PathOrModule = Union[str, Path, ModuleType]


class Library:
    """A collection of protocol buffer modules sorted by package.

    In Python, each .proto file is compiled into a Python module. The Library
    class makes it simple to navigate a collection of Python modules
    corresponding to .proto files, without relying on the location of these
    compiled modules.

    Proto messages and other types can be directly accessed by their protocol
    buffer package name. For example, the foo.bar.Baz message can be accessed
    in a Library called `protos` as:

      protos.packages.foo.bar.Baz

    A Library also provides the modules_by_package dictionary, for looking up
    the list of modules in a particular package, and the modules() generator
    for iterating over all modules.
    """
    @classmethod
    def from_paths(cls, protos: Iterable[PathOrModule]) -> 'Library':
        """Creates a Library from paths to proto files or proto modules.

        str and Path entries are compiled and imported with compile_and_import;
        any other entry is used as an already-imported proto module.
        """
        paths: List[PathOrStr] = []
        modules: List[ModuleType] = []

        for proto in protos:
            if isinstance(proto, (Path, str)):
                paths.append(proto)
            else:
                modules.append(proto)

        if paths:
            modules += compile_and_import(paths)

        # cls rather than Library, so subclasses get instances of their own
        # type (matching from_strings).
        return cls(modules)

    @classmethod
    def from_strings(cls,
                     contents: Iterable[str],
                     includes: Iterable[PathOrStr] = (),
                     output_dir: Optional[PathOrStr] = None) -> 'Library':
        """Creates a proto library from protos in the provided strings."""
        return cls(compile_and_import_strings(contents, includes, output_dir))

    def __init__(self, modules: Iterable[ModuleType]):
        """Constructs a Library from an iterable of modules.

        A Library can be constructed with modules dynamically compiled by
        compile_and_import. For example:

            protos = Library(compile_and_import(list_of_proto_files))
        """
        # Each module is filed under the proto package from its file
        # descriptor.
        self.modules_by_package, self.packages = as_packages(
            (m.DESCRIPTOR.package, m)  # type: ignore[attr-defined]
            for m in modules)

    def modules(self) -> Iterable:
        """Iterates over all protobuf modules in this library."""
        for module_list in self.modules_by_package.values():
            yield from module_list

    def messages(self) -> Iterable:
        """Iterates over all protobuf messages in this library.

        Nested message types are included, each yielded right after its
        enclosing message.
        """
        for module in self.modules():
            yield from _nested_messages(
                module, module.DESCRIPTOR.message_types_by_name)


def _nested_messages(scope, message_names: Iterable[str]) -> Iterator:
    """Yields each named message class in scope, then its nested messages."""
    for name in message_names:
        msg = getattr(scope, name)
        yield msg
        yield from _nested_messages(msg, msg.DESCRIPTOR.nested_types_by_name)


def _repr_char(char: int) -> str:
    r"""Returns how a single byte value is spelled inside a b'...' literal.

    Printable ASCII is returned as the character itself, except that the single
    quote and the backslash are backslash-escaped; any other value becomes a
    \xNN escape.
    """
    if ord(' ') <= char <= ord('~'):
        text = chr(char)
        # Both are special inside b'...': an unescaped ' would end the literal
        # and an unescaped \ would start an escape sequence.
        return '\\' + text if text in ("'", '\\') else text

    return f'\\x{char:02X}'


def bytes_repr(value: bytes) -> str:
    r"""Returns a b'...' literal for value.

    Bytes are shown as mixed ASCII and escapes only if at least half of them
    are printable ASCII; otherwise every byte is written as a \xNN escape.
    """
    ascii_char_count = sum(ord(' ') <= c <= ord('~') for c in value)
    if ascii_char_count >= len(value) / 2:
        contents = ''.join(_repr_char(c) for c in value)
    else:
        contents = ''.join(f'\\x{c:02X}' for c in value)

    return f"b'{contents}'"


def _field_repr(field, value) -> str:
    """Formats one field value according to the field descriptor's type."""
    if field.type == field.TYPE_ENUM:
        try:
            enum = field.enum_type.values_by_number[value]
            return f'{field.enum_type.full_name}.{enum.name}'
        except KeyError:
            # Value not defined in the enum; fall back to the raw number.
            return repr(value)

    if field.type == field.TYPE_MESSAGE:
        return proto_repr(value)

    if field.type == field.TYPE_BYTES:
        return bytes_repr(value)

    return repr(value)


def _proto_repr(message) -> Iterator[str]:
    """Yields 'name=value' strings for the fields of message that are set."""
    for field in message.DESCRIPTOR.fields:
        value = getattr(message, field.name)

        # Skip fields that are not present.
        try:
            if not message.HasField(field.name):
                continue
        except ValueError:
            # Skip default-valued fields that don't support HasField.
            if (field.label != field.LABEL_REPEATED
                    and value == field.default_value):
                continue

        if field.label == field.LABEL_REPEATED:
            if not value:
                continue

            if isinstance(value, Mapping):
                # Map fields: the entry message type has key and value fields.
                key_desc, value_desc = field.message_type.fields
                values = ', '.join(
                    f'{_field_repr(key_desc, k)}: {_field_repr(value_desc, v)}'
                    for k, v in value.items())
                yield f'{field.name}={{{values}}}'
            else:
                values = ', '.join(_field_repr(field, v) for v in value)
                yield f'{field.name}=[{values}]'
        else:
            yield f'{field.name}={_field_repr(field, value)}'


def proto_repr(message) -> str:
    """Creates a repr-like string for a protobuf."""
    return f'{message.DESCRIPTOR.full_name}({", ".join(_proto_repr(message))})'