• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Parses Bazel rules from a local Bazel workspace."""
15
16import json
17import logging
18import os
19import subprocess
20
21from io import StringIO
22from pathlib import Path, PurePosixPath
23from typing import (
24    Any,
25    Callable,
26    IO,
27    Iterable,
28    Iterator,
29)
30from xml.etree import ElementTree
31
# Module logger; `BazelWorkspace._exec` logs a failed Bazel command's stderr
# through it.
_LOG = logging.getLogger(__name__)


# Union of the Python types a parsed Bazel attribute value can take
# (booleans, integers, strings, string lists and string-to-string dicts).
BazelValue = bool | int | str | list[str] | dict[str, str]
36
37
class ParseError(Exception):
    """Signals that output from a Bazel command could not be interpreted.

    Also raised when a Bazel command produces no output or when a rule label
    is malformed.
    """
40
41
def parse_invalid(attr: dict[str, Any]) -> BazelValue:
    """Raises an error reporting that an attribute's type is unrecognized.

    Serves as the fallback attribute parser for types that have no dedicated
    parser, so it never returns normally.

    Args:
        attr: A single attribute entry from the JSON output of
            `bazel cquery ... --output=jsonproto`.

    Raises:
        ParseError: Always. The message names the offending type and, when
            present, the attribute's name.
    """
    # Use .get() so that a malformed entry still yields a ParseError rather
    # than a bare KeyError.
    attr_type = attr.get('type', '<missing>')
    attr_name = attr.get('name', '<unnamed>')
    # The previous message listed the Python types of BazelValue, which are
    # not the Bazel type names a caller could act on; report the attribute
    # instead.
    raise ParseError(f'unknown type: {attr_type} for attribute "{attr_name}"')
46
47
class BazelRule:
    """Represents a Bazel rule as parsed from the query results."""

    def __init__(self, label: str, kind: str) -> None:
        """Create a Bazel rule.

        Args:
            label: An absolute Bazel label corresponding to this rule, either
                "//package:target" or the shorthand "//package/name", which
                names the target "name".
            kind: The type of Bazel rule, e.g. cc_library.

        Raises:
            ParseError: If `label` does not start with "//" or contains more
                than one ":".
        """
        if not label.startswith('//'):
            raise ParseError(f'invalid label: {label}')
        if ':' in label:
            parts = label.split(':')
            if len(parts) != 2:
                raise ParseError(f'invalid label: {label}')
            self._package = parts[0][2:]
            self._target = parts[1]
        else:
            # Shorthand label: "//a/b" is treated as "//a/b:b".
            self._package = label[2:]
            self._target = PurePosixPath(label).name
        self._kind = kind

        self._attrs: dict[str, BazelValue] = {}

    def package(self) -> str:
        """Returns the package portion of this rule's name (without "//")."""
        return self._package

    def label(self) -> str:
        """Returns this rule's full target name as "//package:target"."""
        return f'//{self._package}:{self._target}'

    def kind(self) -> str:
        """Returns this rule's target type."""
        return self._kind

    def parse(self, attrs: Iterable[dict[str, Any]]) -> None:
        """Maps JSON data from a bazel query into this object.

        Only attributes flagged `explicitlySpecified` are stored. Each stored
        value replaces any value previously set under the same name.

        Args:
            attrs: A dictionary of attribute names and values for the Bazel
                rule. These should match the output of
                `bazel cquery ... --output=jsonproto`.

        Raises:
            ParseError: If an explicitly specified attribute lacks a "name"
                or "type" key, or has a type with no parser.
        """
        # Keys are the lower-cased attribute type discriminators. Per Bazel's
        # build.proto, LABEL and OUTPUT values are carried in `stringValue`,
        # and LABEL_LIST / OUTPUT_LIST values in `stringListValue`.
        attr_parsers: dict[str, Callable[[dict[str, Any]], BazelValue]] = {
            'boolean': lambda attr: attr.get('booleanValue', False),
            'integer': lambda attr: int(attr.get('intValue', '0')),
            'string': lambda attr: attr.get('stringValue', ''),
            'label': lambda attr: attr.get('stringValue', ''),
            'output': lambda attr: attr.get('stringValue', ''),
            'string_list': lambda attr: attr.get('stringListValue', []),
            'label_list': lambda attr: attr.get('stringListValue', []),
            'output_list': lambda attr: attr.get('stringListValue', []),
            'string_dict': lambda attr: {
                p['key']: p['value'] for p in attr.get('stringDictValue', [])
            },
        }
        for attr in attrs:
            # Skip attributes not marked as explicitly specified (presumably
            # values Bazel filled in rather than ones written in BUILD files).
            if not attr.get('explicitlySpecified', False):
                continue
            try:
                attr_name = attr['name']
            except KeyError as err:
                raise ParseError(
                    f'missing "name" in {json.dumps(attr, indent=2)}'
                ) from err
            try:
                attr_type = attr['type'].lower()
            except KeyError as err:
                raise ParseError(
                    f'missing "type" in {json.dumps(attr, indent=2)}'
                ) from err

            attr_parser = attr_parsers.get(attr_type, parse_invalid)
            self._attrs[attr_name] = attr_parser(attr)

    def has_attr(self, attr_name: str) -> bool:
        """Returns whether the rule has an attribute of the given name.

        Args:
            attr_name: The name of the attribute.
        """
        return attr_name in self._attrs

    def get_bool(self, attr_name: str) -> bool:
        """Gets the value of a boolean attribute.

        Returns False if the attribute is unset. Fails an `assert` if the
        stored value is not a bool.

        Args:
            attr_name: The name of the boolean attribute.
        """
        val = self._attrs.get(attr_name, False)
        assert isinstance(val, bool)
        return val

    def get_int(self, attr_name: str) -> int:
        """Gets the value of an integer attribute.

        Returns 0 if the attribute is unset. Fails an `assert` if the stored
        value is not an int (note that a bool also passes this check).

        Args:
            attr_name: The name of the integer attribute.
        """
        val = self._attrs.get(attr_name, 0)
        assert isinstance(val, int)
        return val

    def get_str(self, attr_name: str) -> str:
        """Gets the value of a string attribute.

        Returns '' if the attribute is unset. Fails an `assert` if the stored
        value is not a str.

        Args:
            attr_name: The name of the string attribute.
        """
        val = self._attrs.get(attr_name, '')
        assert isinstance(val, str)
        return val

    def get_list(self, attr_name: str) -> list[str]:
        """Gets the value of a string list attribute.

        Returns a new empty list if the attribute is unset; otherwise the
        stored list itself (not a copy). Fails an `assert` if the stored
        value is not a list.

        Args:
            attr_name: The name of the string list attribute.
        """
        val = self._attrs.get(attr_name, [])
        assert isinstance(val, list)
        return val

    def get_dict(self, attr_name: str) -> dict[str, str]:
        """Gets the value of a string dictionary attribute.

        Returns a new empty dict if the attribute is unset; otherwise the
        stored dict itself (not a copy). Fails an `assert` if the stored
        value is not a dict.

        Args:
            attr_name: The name of the string dictionary attribute.
        """
        val = self._attrs.get(attr_name, {})
        assert isinstance(val, dict)
        return val

    def set_attr(self, attr_name: str, value: BazelValue) -> None:
        """Sets the value of an attribute, replacing any existing value.

        Args:
            attr_name: The name of the attribute.
            value: The value to set.
        """
        self._attrs[attr_name] = value
190
191
class BazelWorkspace:
    """Represents a local instance of a Bazel repository.

    Attributes:
        root: Path to the local instance of a Bazel workspace.
        packages: Bazel packages mapped to their default visibility.
        repo: The name of Bazel workspace.
    """

    def __init__(self, pathname: Path) -> None:
        """Creates an object representing a Bazel workspace at the given path.

        Args:
            pathname: Path to the local instance of a Bazel workspace.
        """
        self.root: Path = pathname
        self.packages: dict[str, str] = {}
        # The final path component, e.g. "/src/pigweed" -> "pigweed".
        self.repo: str = os.path.basename(pathname)

    def get_rules(self, kind: str) -> Iterator[BazelRule]:
        """Returns rules matching the given kind, e.g. 'cc_library'.

        Each rule's `visibility` is first set to its package's default
        visibility; an explicitly specified `visibility` attribute read by
        `BazelRule.parse` replaces it.

        Raises:
            ParseError: If a Bazel command produces no output or a rule
                cannot be parsed.
        """
        self._load_packages()
        results = self._query(
            'cquery', f'kind({kind}, //...)', '--output=jsonproto'
        )
        json_data = json.loads(results)
        for result in json_data.get('results', []):
            rule_data = result['target']['rule']
            rule = BazelRule(rule_data['name'], kind)
            default_visibility = self.packages[rule.package()]
            rule.set_attr('visibility', [default_visibility])
            rule.parse(rule_data['attribute'])
            yield rule

    def _load_packages(self) -> None:
        """Scans the workspace for packages and their default visibilities.

        Results are cached in `self.packages`. The scan is skipped when that
        mapping is already non-empty.
        """
        if self.packages:
            return
        packages = self._query('query', '//...', '--output=package').split('\n')
        for package in packages:
            results = self._query(
                'query', f'buildfiles(//{package}:*)', '--output=xml'
            )
            xml_data = ElementTree.fromstring(results)
            # Ignore elements outside this package (presumably files such as
            # .bzl files loaded from other packages).
            prefix = f'//{package}:'
            for pkg_elem in xml_data:
                if not pkg_elem.attrib['name'].startswith(prefix):
                    continue
                for elem in pkg_elem:
                    # NOTE: if several visibility labels appear, the last one
                    # seen wins.
                    if elem.tag == 'visibility-label':
                        self.packages[package] = elem.attrib['name']

    def _query(self, *args: str) -> str:
        """Runs a Bazel command (e.g. `query` or `cquery`) in the workspace.

        Args:
            args: The Bazel command followed by its arguments.

        Returns:
            The command's stdout with surrounding whitespace stripped.
        """
        output = StringIO()
        # _exec already adds --noshow_progress; passing it here duplicated it.
        self._exec(*args, output=output)
        return output.getvalue()

    def _exec(self, *args: str, output: IO | None = None) -> None:
        """Execute a Bazel command in the workspace.

        Args:
            args: The Bazel command (e.g. "query" or "run") followed by its
                arguments.
            output: Optional destination for the command's stripped stdout.

        Raises:
            ParseError: If the command writes nothing to stdout; its stderr
                is logged first.
        """
        # Put --noshow_progress right after the command so Bazel always
        # parses it as its own option. Appending it at the end would forward
        # it to the program when `bazel run` arguments contain a `--`.
        cmdline = ['bazel', *args[:1], '--noshow_progress', *args[1:]]
        result = subprocess.run(
            cmdline,
            cwd=self.root,
            capture_output=True,
        )
        # NOTE(review): the exit status is not checked; only empty stdout is
        # treated as failure.
        if not result.stdout:
            _LOG.error(result.stderr.decode('utf-8'))
            raise ParseError(f'Failed to query Bazel workspace: {self.root}')
        if output is not None:
            output.write(result.stdout.decode('utf-8').strip())

    def run(self, label: str, *args: str, output: IO | None = None) -> None:
        """Invokes `bazel run` on the given label.

        Args:
            label: A Bazel target, e.g. "@repo//package:target".
            args: Additional options to pass to `bazel`.
            output: Optional destination for the output of the command.

        Raises:
            ParseError: If the command writes nothing to stdout.
        """
        self._exec('run', label, *args, output=output)

    def revision(self) -> str:
        """Returns the output of `git rev-parse HEAD` in the workspace.

        Raises:
            subprocess.CalledProcessError: If git exits with an error.
        """
        return self._git('rev-parse', 'HEAD')

    def url(self) -> str:
        """Returns the URL of the workspace's `origin` git remote.

        Raises:
            subprocess.CalledProcessError: If git exits with an error.
        """
        return self._git('remote', 'get-url', 'origin')

    def _git(self, *args: str) -> str:
        """Runs a git command in the workspace and returns its stripped stdout.

        Raises:
            subprocess.CalledProcessError: If git exits with a non-zero
                status; its stderr is logged before re-raising.
        """
        try:
            result = subprocess.run(
                ['git', *args],
                cwd=self.root,
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as error:
            # Report through the module logger, consistent with _exec.
            _LOG.error(error.stderr.decode('utf-8'))
            raise
        return result.stdout.decode('utf-8').strip()
299