• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tools for compiling and importing Python protos on the fly."""
15
16from collections.abc import Mapping
17import importlib.util
18import logging
19import os
20from pathlib import Path
21import subprocess
22import shlex
23import tempfile
24from types import ModuleType
25from typing import (
26    Dict,
27    Generic,
28    Iterable,
29    Iterator,
30    List,
31    NamedTuple,
32    Optional,
33    Set,
34    Tuple,
35    TypeVar,
36    Union,
37)
38
try:
    # black is an optional dependency: proto_repr() uses it to wrap its output
    # and falls back to unwrapped output when it is not installed.
    # pylint: disable=wrong-import-position
    import black

    # string_normalization=False keeps black from rewriting string quotes.
    black_mode: Optional[black.Mode] = black.Mode(string_normalization=False)

    # pylint: enable=wrong-import-position
except ImportError:
    # Sentinels checked by proto_repr() before it tries to format anything.
    black = None  # type: ignore
    black_mode = None

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

# Type accepted wherever this module takes a filesystem path.
PathOrStr = Union[Path, str]
53
54
def compile_protos(
    output_dir: PathOrStr,
    proto_files: Iterable[PathOrStr],
    includes: Iterable[PathOrStr] = (),
) -> None:
    """Compiles proto files for Python by invoking the protobuf compiler.

    Proto files not covered by one of the provided include paths will have their
    directory added as an include path.

    Args:
      output_dir: directory passed to protoc as --python_out
      proto_files: paths to the .proto files to compile
      includes: include (-I) directories, passed to protoc in the given order

    Raises:
      subprocess.CalledProcessError: protoc exited with a non-zero status
    """
    proto_paths: List[Path] = [Path(f).resolve() for f in proto_files]

    # protoc searches -I directories in the order they are given, so the
    # caller's order must survive de-duplication. A dict is used as an ordered
    # set; a plain set would emit the flags in arbitrary order.
    include_paths: Dict[Path, None] = dict.fromkeys(
        Path(d).resolve() for d in includes
    )

    for path in proto_paths:
        if not any(include in path.parents for include in include_paths):
            include_paths[path.parent] = None

    cmd: Tuple[PathOrStr, ...] = (
        'protoc',
        '--experimental_allow_proto3_optional',
        '--python_out',
        os.path.abspath(output_dir),
        *(f'-I{d}' for d in include_paths),
        *proto_paths,
    )

    command_line = ' '.join(shlex.quote(str(c)) for c in cmd)
    _LOG.debug('%s', command_line)
    process = subprocess.run(cmd, capture_output=True, check=False)

    if process.returncode:
        _LOG.error(
            'protoc invocation failed!\n%s\n%s',
            command_line,
            # Never let undecodable compiler output mask the real failure.
            process.stderr.decode(errors='replace'),
        )
        process.check_returncode()
91
92
def _import_module(name: str, path: str) -> ModuleType:
    """Loads the Python source file at `path` as a module called `name`.

    The module is executed, but this function does not add it to sys.modules.

    Args:
      name: fully qualified name to give the module
      path: location of the file to load

    Returns:
      the executed module object

    Raises:
      ImportError: no import spec or loader could be found for the file
    """
    spec = importlib.util.spec_from_file_location(name, path)

    # Raise explicitly instead of asserting: asserts are stripped when Python
    # runs with -O, which would turn this into an obscure AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(
            f'Cannot import module {name!r} from {str(path)!r}',
            name=name,
            path=str(path),
        )

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
99
100
def import_modules(directory: PathOrStr) -> Iterator:
    """Imports modules in a directory and yields them.

    The directory is walked recursively and every file ending in .py is
    imported. Each module is named with its dotted path relative to the parent
    of `directory`, so names start with the directory's own name.

    Args:
      directory: root of the tree to import

    Yields:
      each imported module
    """
    # Normalize before taking the parent. For a path with a trailing separator
    # dirname() would return the directory itself, and module names would then
    # come out with a leading dot (e.g. "..foo_pb2").
    parent = os.path.dirname(os.path.abspath(directory))

    for dirpath, _, files in os.walk(directory):
        package = '.'.join(os.path.relpath(dirpath, parent).split(os.sep))

        for file in files:
            name, ext = os.path.splitext(file)

            if ext == '.py':
                yield _import_module(
                    f'{package}.{name}',
                    os.path.join(dirpath, file),
                )
116
117
def compile_and_import(
    proto_files: Iterable[PathOrStr],
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
) -> Iterator:
    """Compiles protos and imports their modules; yields the proto modules.

    Args:
      proto_files: paths to .proto files to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted

    Yields:
      the generated protobuf Python modules
    """
    if not output_dir:
        # No destination given: generate into a scratch directory that lives
        # only while the modules are being imported.
        with tempfile.TemporaryDirectory(prefix='compiled_protos_') as scratch:
            compile_protos(scratch, proto_files, includes)
            yield from import_modules(scratch)
        return

    compile_protos(output_dir, proto_files, includes)
    yield from import_modules(output_dir)
142
143
def compile_and_import_file(
    proto_file: PathOrStr,
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
):
    """Compiles one .proto file and returns the first module it produces."""
    modules = compile_and_import([proto_file], includes, output_dir)
    return next(modules)
151
152
def compile_and_import_strings(
    contents: Iterable[str],
    includes: Iterable[PathOrStr] = (),
    output_dir: Optional[PathOrStr] = None,
) -> Iterator:
    """Compiles protos in one or more strings.

    Args:
      contents: .proto source text; a single string is also accepted
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; a temporary directory is
          used if omitted

    Yields:
      the generated protobuf Python modules
    """

    if isinstance(contents, str):
        contents = [contents]

    with tempfile.TemporaryDirectory(prefix='proto_sources_') as path:
        protos = []

        for proto in contents:
            # Use a hash of the proto so the same contents map to the same file
            # name. The protobuf package complains if it sees the same contents
            # in files with different names.
            proto_path = Path(path, f'protobuf_{hash(proto):x}.proto')
            # Write UTF-8 explicitly; the platform default encoding could
            # otherwise garble non-ASCII text before protoc reads it.
            proto_path.write_text(proto, encoding='utf-8')
            protos.append(proto_path)

        yield from compile_and_import(protos, includes, output_dir)
174
175
T = TypeVar('T')


class _NestedPackage(Generic[T]):
    """Facilitates navigating protobuf packages as attributes.

    Each instance represents one package. It holds its direct subpackages by
    name and a list of items (for example, compiled proto modules) whose
    attributes are exposed as attributes of the package.
    """

    def __init__(self, package: str):
        self._packages: Dict[str, _NestedPackage[T]] = {}
        self._items: List[T] = []
        self._package = package

    def _add_package(self, subpackage: str, package: '_NestedPackage') -> None:
        """Registers `package` as the direct child named `subpackage`."""
        self._packages[subpackage] = package

    def _add_item(self, item: T) -> None:
        """Adds an item whose attributes become members of this package."""
        if item not in self._items:  # Don't store the same item multiple times.
            self._items.append(item)

    def __getattr__(self, attr: str):
        """Look up subpackages or package members.

        Subpackages take priority; otherwise the first item that has the
        attribute provides it.

        Raises:
          AttributeError: neither a subpackage nor any item has `attr`
        """
        # __getattr__ only runs after normal lookup fails. Read the state
        # through __dict__ so that an instance whose __init__ has not run
        # (copy and pickle probe such instances for __setstate__) raises
        # AttributeError instead of recursing forever on self._packages.
        state = self.__dict__

        packages = state.get('_packages', {})
        if attr in packages:
            return packages[attr]

        for item in state.get('_items', ()):
            if hasattr(item, attr):
                return getattr(item, attr)

        package = state.get('_package', '')
        raise AttributeError(
            f'Proto package "{package}" does not contain "{attr}"'
        )

    def __getitem__(self, subpackage: str) -> '_NestedPackage[T]':
        """Support accessing nested packages by name.

        Args:
          subpackage: dotted path relative to this package, e.g. "foo.bar"

        Raises:
          KeyError: a component of the path is not a known subpackage
        """
        result: '_NestedPackage[T]' = self

        for package in subpackage.split('.'):
            result = result._packages[package]

        return result

    def __dir__(self) -> List[str]:
        """List subpackages and members of modules as attributes."""
        attributes = list(self._packages)

        for item in self._items:
            for attr, value in vars(item).items():
                # Exclude private variables and modules from dir().
                if not attr.startswith('_') and not isinstance(
                    value, ModuleType
                ):
                    attributes.append(attr)

        return attributes

    def __iter__(self) -> Iterator['_NestedPackage[T]']:
        """Iterate over nested packages."""
        return iter(self._packages.values())

    def __repr__(self) -> str:
        msg = [f'ProtoPackage({self._package!r}']

        # Members are the public attributes contributed by the items. The
        # instance's own attributes are all private, so they cannot be used
        # to find members; reuse __dir__ and drop the subpackage names.
        public_members = [
            name
            for name in dict.fromkeys(self.__dir__())
            if name not in self._packages
        ]
        if public_members:
            msg.append(f'members={str(public_members)}')

        if self._packages:
            msg.append(f'subpackages={str(list(self._packages))}')

        return ', '.join(msg) + ')'

    def __str__(self) -> str:
        return self._package
252
253
class Packages(NamedTuple):
    """Items in a protobuf package structure; returned from as_packages()."""

    # Maps each dotted package name to the list of items added under it.
    items_by_package: Dict[str, List]
    # Root of the package tree; subpackages are reachable as attributes.
    packages: _NestedPackage
259
260
def as_packages(
    items: Iterable[Tuple[str, T]], packages: Optional[Packages] = None
) -> Packages:
    """Places items in a proto-style package structure navigable by attributes.

    Args:
      items: (package, item) tuples to insert into the package structure
      packages: if provided, update this Packages instead of creating a new one

    Returns:
      the Packages that was created or updated
    """
    if packages is None:
        packages = Packages({}, _NestedPackage(''))

    by_package, root = packages

    # pylint: disable=protected-access
    for package, item in items:
        by_package.setdefault(package, []).append(item)

        # Walk down from the root, creating any package nodes that are missing
        # along the dotted path.
        node = root
        qualified: List[str] = []

        for part in package.split('.'):
            qualified.append(part)

            if part not in node._packages:
                node._add_package(part, _NestedPackage('.'.join(qualified)))

            node = node._packages[part]

        node._add_item(item)
    # pylint: enable=protected-access

    return packages
292
293
# Accepted by Library.from_paths(): a path to a .proto file or an imported module.
PathOrModule = Union[str, Path, ModuleType]
295
296
class Library:
    """A collection of protocol buffer modules sorted by package.

    In Python, each .proto file is compiled into a Python module. The Library
    class makes it simple to navigate a collection of Python modules
    corresponding to .proto files, without relying on the location of these
    compiled modules.

    Proto messages and other types can be directly accessed by their protocol
    buffer package name. For example, the foo.bar.Baz message can be accessed
    in a Library called `protos` as:

      protos.packages.foo.bar.Baz

    A Library also provides the modules_by_package dictionary, for looking up
    the list of modules in a particular package, and the modules() generator
    for iterating over all modules.
    """

    @classmethod
    def from_paths(cls, protos: Iterable[PathOrModule]) -> 'Library':
        """Creates a Library from paths to proto files or proto modules.

        Args:
          protos: paths to .proto files, which are compiled and imported,
              and/or modules, which are used as they are
        """
        paths: List[PathOrStr] = []
        modules: List[ModuleType] = []

        for proto in protos:
            if isinstance(proto, (Path, str)):
                paths.append(proto)
            else:
                modules.append(proto)

        if paths:
            modules += compile_and_import(paths)

        # Construct through cls, as from_strings() does, so that subclasses
        # get an instance of their own type.
        return cls(modules)

    @classmethod
    def from_strings(
        cls,
        contents: Iterable[str],
        includes: Iterable[PathOrStr] = (),
        output_dir: Optional[PathOrStr] = None,
    ) -> 'Library':
        """Creates a proto library from protos in the provided strings.

        Args:
          contents: .proto source text
          includes: include paths to use for .proto compilation
          output_dir: where to place the generated modules
        """
        return cls(compile_and_import_strings(contents, includes, output_dir))

    def __init__(self, modules: Iterable[ModuleType]):
        """Constructs a Library from an iterable of modules.

        A Library can be constructed with modules dynamically compiled by
        compile_and_import. For example:

            protos = Library(compile_and_import(list_of_proto_files))
        """
        # modules_by_package: package name -> list of modules in that package.
        # packages: attribute-navigable tree of those packages.
        self.modules_by_package, self.packages = as_packages(
            (m.DESCRIPTOR.package, m)  # type: ignore[attr-defined]
            for m in modules
        )

    def modules(self) -> Iterable:
        """Iterates over all protobuf modules in this library."""
        for module_list in self.modules_by_package.values():
            yield from module_list

    def messages(self) -> Iterable:
        """Iterates over all protobuf messages in this library.

        Nested message types are included.
        """
        for module in self.modules():
            yield from _nested_messages(
                module, module.DESCRIPTOR.message_types_by_name
            )
366
367
def _nested_messages(scope, message_names: Iterable[str]) -> Iterator:
    """Yields the named messages of `scope`, each followed by its nested ones.

    Messages are produced depth-first: a message comes out before the types
    nested inside it, which come out before its next sibling.
    """
    # Stack of (owner, iterator over the names still to visit in that owner).
    pending = [(scope, iter(message_names))]

    while pending:
        owner, names = pending[-1]

        for name in names:
            message = getattr(owner, name)
            yield message
            # Descend into this message before continuing with its siblings.
            pending.append(
                (message, iter(message.DESCRIPTOR.nested_types_by_name))
            )
            break
        else:
            # Every name in this owner has been visited.
            pending.pop()
373
374
def _repr_char(char: int) -> str:
    r"""Returns an ASCII char or the \x code for non-printable values.

    Single quotes and backslashes are backslash-escaped so that the result
    can be placed inside a b'...' literal without changing its meaning.

    Args:
      char: a byte value (0-255)
    """
    if ord(' ') <= char <= ord('~'):
        text = chr(char)
        # A bare quote would end the literal, and a bare backslash would start
        # an escape sequence (or swallow the closing quote).
        return f'\\{text}' if text in ("'", '\\') else text

    return f'\\x{char:02X}'
381
382
def bytes_repr(value: bytes) -> str:
    """Formats bytes as a b'...' string, using ASCII only when it helps.

    When at least half of the bytes are printable ASCII, printable bytes are
    shown as characters and the rest as \\x codes. Otherwise every byte is
    shown as a \\x code.
    """
    printable = sum(1 for byte in value if 0x20 <= byte <= 0x7E)

    if 2 * printable < len(value):
        # Mostly binary: a uniform hex dump reads better than a mix.
        body = ''.join(f'\\x{byte:02X}' for byte in value)
    else:
        body = ''.join(_repr_char(byte) for byte in value)

    return f"b'{body}'"
392
393
def _field_repr(field, value) -> str:
    """Formats a single field value according to the field's declared type.

    Args:
      field: the field's descriptor
      value: the value to format; for repeated fields, one element

    Returns:
      the qualified enum value name for known enum numbers, a nested
      proto_repr() for messages, bytes_repr() for bytes, and repr() otherwise
    """
    if field.type == field.TYPE_ENUM:
        try:
            enum = field.enum_type.values_by_number[value]
            return f'{field.enum_type.full_name}.{enum.name}'
        except KeyError:
            # Number without a named enum value; show it as a plain number.
            return repr(value)

    if field.type == field.TYPE_MESSAGE:
        # Never wrap nested messages here. The outermost proto_repr() call
        # formats the complete string once if wrapping was requested; wrapping
        # at every level would ignore wrap=False and rerun the formatter for
        # each nested message.
        return proto_repr(value, wrap=False)

    if field.type == field.TYPE_BYTES:
        return bytes_repr(value)

    return repr(value)
409
410
def _proto_repr(message) -> Iterator[str]:
    """Yields a 'name=value' string for each populated field of a message.

    Fields are visited in descriptor order. Map fields are rendered as
    {key: value, ...} and other repeated fields as [value, ...].
    """
    for field in message.DESCRIPTOR.fields:
        name = field.name
        value = getattr(message, name)
        repeated = field.label == field.LABEL_REPEATED

        try:
            absent = not message.HasField(name)
        except ValueError:
            # HasField is not supported for this field, so treat a singular
            # field as absent when it still holds its default value.
            absent = not repeated and value == field.default_value

        if absent:
            continue

        if not repeated:
            yield f'{name}={_field_repr(field, value)}'
            continue

        # Empty repeated and map fields are omitted.
        if not value:
            continue

        if isinstance(value, Mapping):
            key_desc, value_desc = field.message_type.fields
            entries = ', '.join(
                f'{_field_repr(key_desc, k)}: {_field_repr(value_desc, v)}'
                for k, v in value.items()
            )
            yield f'{name}={{{entries}}}'
        else:
            elements = ', '.join(_field_repr(field, v) for v in value)
            yield f'{name}=[{elements}]'
443
444
def proto_repr(message, *, wrap: bool = True) -> str:
    """Creates a repr-like string for a protobuf.

    In an interactive console that imports proto objects into the namespace, the
    output of proto_repr() can be used as Python source to create a proto
    object.

    Args:
      message: The protobuf message to format
      wrap: If true and black is available, the output is wrapped according to
          PEP8 using black.

    Returns:
      the message's full name followed by its populated fields as keyword
      arguments
    """
    type_name = message.DESCRIPTOR.full_name
    arguments = ', '.join(_proto_repr(message))
    raw = f'{type_name}({arguments})'

    # Return the single-line form when wrapping is off or black is missing.
    if not wrap or black is None or black_mode is None:
        return raw

    return black.format_str(raw, mode=black_mode).strip()
463