1# Copyright 2023 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Parses Bazel rules from a local Bazel workspace.""" 15 16import json 17import logging 18import os 19import subprocess 20 21from io import StringIO 22from pathlib import Path, PurePosixPath 23from typing import ( 24 Any, 25 Callable, 26 IO, 27 Iterable, 28 Iterator, 29) 30from xml.etree import ElementTree 31 32_LOG = logging.getLogger(__name__) 33 34 35BazelValue = bool | int | str | list[str] | dict[str, str] 36 37 38class ParseError(Exception): 39 """Raised when a Bazel query returns data that can't be parsed.""" 40 41 42def parse_invalid(attr: dict[str, Any]) -> BazelValue: 43 """Raises an error that a type is unrecognized.""" 44 attr_type = attr['type'] 45 raise ParseError(f'unknown type: {attr_type}, expected one of {BazelValue}') 46 47 48class BazelRule: 49 """Represents a Bazel rule as parsed from the query results.""" 50 51 def __init__(self, label: str, kind: str) -> None: 52 """Create a Bazel rule. 53 54 Args: 55 label: An absolute Bazel label corresponding to this rule. 56 kind: The type of Bazel rule, e.g. cc_library. 57 """ 58 if not label.startswith('//'): 59 raise ParseError(f'invalid label: {label}') 60 if ':' in label: 61 parts = label.split(':') 62 if len(parts) != 2: 63 raise ParseError(f'invalid label: {label}') 64 self._package = parts[0][2:] 65 self._target = parts[1] 66 else: 67 self._package = str(label)[2:] 68 self._target = PurePosixPath(label).name 69 self._kind = kind 70 71 self._attrs: dict[str, BazelValue] = {} 72 73 def package(self) -> str: 74 """Returns the package portion of this rule's name.""" 75 return self._package 76 77 def label(self) -> str: 78 """Returns this rule's full target name.""" 79 return f'//{self._package}:{self._target}' 80 81 def kind(self) -> str: 82 """Returns this rule's target type.""" 83 return self._kind 84 85 def parse(self, attrs: Iterable[dict[str, Any]]) -> None: 86 """Maps JSON data from a bazel query into this object. 87 88 Args: 89 attrs: A dictionary of attribute names and values for the Bazel 90 rule. These should match the output of 91 `bazel cquery ... --output=jsonproto`. 92 """ 93 attr_parsers: dict[str, Callable[[dict[str, Any]], BazelValue]] = { 94 'boolean': lambda attr: attr.get('booleanValue', False), 95 'integer': lambda attr: int(attr.get('intValue', '0')), 96 'string': lambda attr: attr.get('stringValue', ''), 97 'string_list': lambda attr: attr.get('stringListValue', []), 98 'label_list': lambda attr: attr.get('stringListValue', []), 99 'string_dict': lambda attr: { 100 p['key']: p['value'] for p in attr.get('stringDictValue', []) 101 }, 102 } 103 for attr in attrs: 104 if 'explicitlySpecified' not in attr: 105 continue 106 if not attr['explicitlySpecified']: 107 continue 108 try: 109 attr_name = attr['name'] 110 except KeyError: 111 raise ParseError( 112 f'missing "name" in {json.dumps(attr, indent=2)}' 113 ) 114 try: 115 attr_type = attr['type'].lower() 116 except KeyError: 117 raise ParseError( 118 f'missing "type" in {json.dumps(attr, indent=2)}' 119 ) 120 121 attr_parser = attr_parsers.get(attr_type, parse_invalid) 122 self._attrs[attr_name] = attr_parser(attr) 123 124 def has_attr(self, attr_name: str) -> bool: 125 """Returns whether the rule has an attribute of the given name. 126 127 Args: 128 attr_name: The name of the attribute. 129 """ 130 return attr_name in self._attrs 131 132 def get_bool(self, attr_name: str) -> bool: 133 """Gets the value of a boolean attribute. 134 135 Args: 136 attr_name: The name of the boolean attribute. 137 """ 138 val = self._attrs.get(attr_name, False) 139 assert isinstance(val, bool) 140 return val 141 142 def get_int(self, attr_name: str) -> int: 143 """Gets the value of an integer attribute. 144 145 Args: 146 attr_name: The name of the integer attribute. 147 """ 148 val = self._attrs.get(attr_name, 0) 149 assert isinstance(val, int) 150 return val 151 152 def get_str(self, attr_name: str) -> str: 153 """Gets the value of a string attribute. 154 155 Args: 156 attr_name: The name of the string attribute. 157 """ 158 val = self._attrs.get(attr_name, '') 159 assert isinstance(val, str) 160 return val 161 162 def get_list(self, attr_name: str) -> list[str]: 163 """Gets the value of a string list attribute. 164 165 Args: 166 attr_name: The name of the string list attribute. 167 """ 168 val = self._attrs.get(attr_name, []) 169 assert isinstance(val, list) 170 return val 171 172 def get_dict(self, attr_name: str) -> dict[str, str]: 173 """Gets the value of a string list attribute. 174 175 Args: 176 attr_name: The name of the string list attribute. 177 """ 178 val = self._attrs.get(attr_name, {}) 179 assert isinstance(val, dict) 180 return val 181 182 def set_attr(self, attr_name: str, value: BazelValue) -> None: 183 """Sets the value of an attribute. 184 185 Args: 186 attr_name: The name of the attribute. 187 value: The value to set. 188 """ 189 self._attrs[attr_name] = value 190 191 192class BazelWorkspace: 193 """Represents a local instance of a Bazel repository. 194 195 Attributes: 196 root: Path to the local instance of a Bazel workspace. 197 packages: Bazel packages mapped to their default visibility. 198 repo: The name of Bazel workspace. 199 """ 200 201 def __init__(self, pathname: Path) -> None: 202 """Creates an object representing a Bazel workspace at the given path. 203 204 Args: 205 pathname: Path to the local instance of a Bazel workspace. 206 """ 207 self.root: Path = pathname 208 self.packages: dict[str, str] = {} 209 self.repo: str = os.path.basename(pathname) 210 211 def get_rules(self, kind: str) -> Iterator[BazelRule]: 212 """Returns rules matching the given kind, e.g. 'cc_library'.""" 213 self._load_packages() 214 results = self._query( 215 'cquery', f'kind({kind}, //...)', '--output=jsonproto' 216 ) 217 json_data = json.loads(results) 218 for result in json_data.get('results', []): 219 rule_data = result['target']['rule'] 220 target = rule_data['name'] 221 rule = BazelRule(target, kind) 222 default_visibility = self.packages[rule.package()] 223 rule.set_attr('visibility', [default_visibility]) 224 rule.parse(rule_data['attribute']) 225 yield rule 226 227 def _load_packages(self) -> None: 228 """Scans the workspace for packages and their default visibilities.""" 229 if self.packages: 230 return 231 packages = self._query('query', '//...', '--output=package').split('\n') 232 for package in packages: 233 results = self._query( 234 'query', f'buildfiles(//{package}:*)', '--output=xml' 235 ) 236 xml_data = ElementTree.fromstring(results) 237 for pkg_elem in xml_data: 238 if not pkg_elem.attrib['name'].startswith(f'//{package}:'): 239 continue 240 for elem in pkg_elem: 241 if elem.tag == 'visibility-label': 242 self.packages[package] = elem.attrib['name'] 243 244 def _query(self, *args: str) -> str: 245 """Invokes `bazel cquery` with the given selector.""" 246 output = StringIO() 247 self._exec(*args, '--noshow_progress', output=output) 248 return output.getvalue() 249 250 def _exec(self, *args: str, output: IO | None = None) -> None: 251 """Execute a Bazel command in the workspace.""" 252 cmdline = ['bazel'] + list(args) + ['--noshow_progress'] 253 result = subprocess.run( 254 cmdline, 255 cwd=self.root, 256 capture_output=True, 257 ) 258 if not result.stdout: 259 _LOG.error(result.stderr.decode('utf-8')) 260 raise ParseError(f'Failed to query Bazel workspace: {self.root}') 261 if output: 262 output.write(result.stdout.decode('utf-8').strip()) 263 264 def run(self, label: str, *args, output: IO | None = None) -> None: 265 """Invokes `bazel run` on the given label. 266 267 Args: 268 label: A Bazel target, e.g. "@repo//package:target". 269 args: Additional options to pass to `bazel`. 270 output: Optional destination for the output of the command. 271 """ 272 self._exec('run', label, *args, output=output) 273 274 def revision(self) -> str: 275 try: 276 result = subprocess.run( 277 ['git', 'rev-parse', 'HEAD'], 278 cwd=self.root, 279 check=True, 280 capture_output=True, 281 ) 282 except subprocess.CalledProcessError as error: 283 print(error.stderr.decode('utf-8')) 284 raise 285 return result.stdout.decode('utf-8').strip() 286 287 def url(self) -> str: 288 try: 289 result = subprocess.run( 290 ['git', 'remote', 'get-url', 'origin'], 291 cwd=self.root, 292 check=True, 293 capture_output=True, 294 ) 295 except subprocess.CalledProcessError as error: 296 print(error.stderr.decode('utf-8')) 297 raise 298 return result.stdout.decode('utf-8').strip() 299