• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python3
2# Copyright 2023 The Chromium Authors
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5"""Helper script to use GN's JSON interface to make changes.
6
7AST implementation details:
8  https://gn.googlesource.com/gn/+/refs/heads/main/src/gn/parse_tree.cc
9
10To dump an AST:
11  gn format --dump-tree=json BUILD.gn > foo.json
12"""
13
14from __future__ import annotations
15
16import dataclasses
17import functools
18import json
19import subprocess
20from typing import Callable, Dict, List, Optional, Tuple, TypeVar
21
# Keys looked up on the JSON AST node dicts handled by this module.
NODE_CHILD = 'child'
NODE_TYPE = 'type'
NODE_VALUE = 'value'

# Type of the non-None values returned by a NodeWrapper.visit_nodes() callback.
_T = TypeVar('_T')
27
28
29def _create_location_node(begin_line=1):
30    return {
31        'begin_column': 1,
32        'begin_line': begin_line,
33        'end_column': 2,
34        'end_line': begin_line,
35    }
36
37
def _wrap(node: dict):
    """Returns the most specific wrapper available for |node|.

    BLOCK nodes are wrapped in a BlockWrapper, LIST nodes in a StringList, and
    nodes of any other type in a plain NodeWrapper.
    """
    node_type = node[NODE_TYPE]
    if node_type == 'BLOCK':
        return BlockWrapper(node)
    if node_type == 'LIST':
        return StringList(node)
    return NodeWrapper(node)
45
46
def _unwrap(thing):
    """Returns the raw node behind |thing|.

    A NodeWrapper (or subclass) instance yields its .node; anything else is
    handed back unchanged.
    """
    return thing.node if isinstance(thing, NodeWrapper) else thing
51
52
def _find_node(root_node: dict, target_node: dict):
    """Locates |target_node| (compared by identity) beneath |root_node|.

    The tree is walked depth-first in pre-order through the child lists.

    Returns:
      A (parent_node, index) tuple such that
      parent_node[NODE_CHILD][index] is |target_node|.

    Raises:
      Exception: If |target_node| is not a descendant of |root_node|. The
        root itself is never matched, since only children are compared.
    """
    def search(parent: dict) -> Optional[Tuple[dict, int]]:
        # A missing or empty child list leaves nothing to look at.
        for idx, candidate in enumerate(parent.get(NODE_CHILD) or ()):
            if candidate is target_node:
                return parent, idx
            found = search(candidate)
            if found is not None:
                return found
        return None

    result = search(root_node)
    if result is not None:
        return result
    raise Exception(
        f'Node not found: {target_node}\nLooked in: {root_node}')
70
@dataclasses.dataclass
class NodeWrapper:
    """Accessors and tree-editing helpers for a single AST node dict.

    The wrapper keeps a reference to the dict (not a copy), so edits made
    through it show up in the tree the dict belongs to.
    """
    # The raw node dict being wrapped.
    node: dict

    @property
    def node_type(self) -> str:
        """The node's NODE_TYPE entry."""
        return self.node[NODE_TYPE]

    @property
    def node_value(self) -> str:
        """The node's NODE_VALUE entry."""
        return self.node[NODE_VALUE]

    @property
    def node_children(self) -> List[dict]:
        """The node's NODE_CHILD list itself, so it can be edited in place."""
        return self.node[NODE_CHILD]

    @functools.cached_property
    def first_child(self):
        """Wrapper for the child at index 0, computed once and then cached."""
        children = self.node_children
        return _wrap(children[0])

    @functools.cached_property
    def second_child(self):
        """Wrapper for the child at index 1, computed once and then cached."""
        children = self.node_children
        return _wrap(children[1])

    def is_list(self):
        """Returns whether this is a LIST node."""
        return self.node[NODE_TYPE] == 'LIST'

    def is_identifier(self):
        """Returns whether this is an IDENTIFIER node."""
        return self.node[NODE_TYPE] == 'IDENTIFIER'

    def visit_nodes(self, callback: Callable[[dict],
                                             Optional[_T]]) -> List[_T]:
        """Calls |callback| on this node and its descendants, in pre-order.

        Every non-None value returned by |callback| is collected. The children
        of a node for which |callback| returned a value are not visited.

        Returns:
          The collected values, in visit order.
        """
        collected: List[_T] = []

        def walk(current: dict):
            result = callback(current)
            if result is not None:
                # A match ends the descent into this subtree.
                collected.append(result)
                return
            for child in current.get(NODE_CHILD) or ():
                walk(child)

        walk(self.node)
        return collected

    def set_location_recursive(self, line):
        """Moves this node and all of its descendants onto |line|.

        Only nodes that have a non-empty 'location' entry are updated; columns
        are left as they are.
        """
        def relocate(n: dict):
            location = n.get('location')
            if location:
                location['begin_line'] = location['end_line'] = line

        # relocate() returns None, so the walk reaches every node.
        self.visit_nodes(relocate)

    def add_child(self, node, *, before=None):
        """Adds |node| (a wrapper or a raw dict) beneath this node.

        With no |before|, |node| is appended to this node's own children.
        Otherwise |before| is searched for among the descendants and |node| is
        inserted right in front of it, in whichever child list holds it.
        """
        new_node = _unwrap(node)
        if before is None:
            self.node_children.append(new_node)
            return

        anchor = _unwrap(before)
        owner, position = _find_node(self.node, anchor)
        owner[NODE_CHILD].insert(position, new_node)

        # Prevent blank lines between |before| and |node|.
        anchor_line = anchor['location']['begin_line']
        _wrap(new_node).set_location_recursive(anchor_line)

    def remove_child(self, node):
        """Removes |node| (a wrapper or a raw dict) from among the descendants.

        The node is located by identity and dropped from the child list that
        holds it.
        """
        doomed = _unwrap(node)
        owner, position = _find_node(self.node, doomed)
        owner[NODE_CHILD].pop(position)
145
146
@dataclasses.dataclass
class BlockWrapper(NodeWrapper):
    """Wraps a BLOCK node."""
    def __post_init__(self):
        assert self.node_type == 'BLOCK'

    def find_assignments(self, var_name=None):
        """Returns AssignmentWrappers for the assignments under this block.

        When |var_name| is given, only assignments to that variable are
        returned. The children of a returned assignment are not searched.
        """
        def as_wanted_assignment(candidate: dict):
            wrapper = AssignmentWrapper.from_node(candidate)
            if wrapper is None:
                return None
            wanted = var_name is None or var_name == wrapper.variable_name
            return wrapper if wanted else None

        return self.visit_nodes(as_wanted_assignment)
163
164
@dataclasses.dataclass
class AssignmentWrapper(NodeWrapper):
    """Wraps a =, +=, or -= BINARY node where the LHS is an identifier."""
    def __post_init__(self):
        assert self.node_type == 'BINARY'

    @property
    def variable_name(self):
        """Value of the left-hand child, i.e. the variable being assigned."""
        return self.first_child.node_value

    @property
    def value(self):
        """Wrapper for the right-hand child."""
        return self.second_child

    @property
    def list_value(self):
        """The right-hand child, which must have been wrapped as a StringList."""
        rhs = self.second_child
        assert isinstance(rhs, StringList), 'Found: ' + rhs.node_type
        return rhs

    @property
    def operation(self):
        """The operator stored as this node's value.

        Nodes accepted by from_node() carry "=", "+=" or "-=".
        """
        return self.node_value

    @property
    def is_append(self):
        """Whether the operator is "+="."""
        return self.node_value == '+='

    def value_as_string_list(self):
        """Returns a new StringList built around the right-hand child's node."""
        return StringList(self.second_child.node)

    @staticmethod
    def from_node(node: dict) -> Optional[AssignmentWrapper]:
        """Returns an AssignmentWrapper if |node| is one, None otherwise.

        |node| qualifies when it is a BINARY node whose left child is an
        IDENTIFIER and whose operator is "=", "+=" or "-=".
        """
        if node.get(NODE_TYPE) != 'BINARY':
            return None
        operands = node[NODE_CHILD]
        assert len(operands) == 2, (
            'Binary nodes should have two child nodes, but the node is: '
            f'{node}')
        lhs, _ = operands
        is_assignment = (lhs.get(NODE_TYPE) == 'IDENTIFIER'
                         and node.get(NODE_VALUE) in ('=', '+=', '-='))
        return AssignmentWrapper(node) if is_assignment else None

    @staticmethod
    def create(variable_name, value, operation='='):
        """Builds a new assignment of |value| to |variable_name|.

        |value| may be a wrapper or a raw node dict. The new identifier and
        binary nodes are given default locations.
        """
        rhs_node = _unwrap(value)
        lhs_node = {
            'location': _create_location_node(),
            'type': 'IDENTIFIER',
            'value': variable_name,
        }
        binary_node = {
            'location': _create_location_node(),
            'child': [lhs_node, rhs_node],
            'type': 'BINARY',
            'value': operation,
        }
        return AssignmentWrapper(binary_node)

    @staticmethod
    def create_list(variable_name, operation='='):
        """Builds a new assignment whose right-hand side is an empty list."""
        empty_list = StringList.create()
        return AssignmentWrapper.create(variable_name,
                                        empty_list,
                                        operation=operation)
232
233
@dataclasses.dataclass
class StringList(NodeWrapper):
    """Wraps a list node that contains only string literals.

    The same data lives in two places that are updated together: the LITERAL
    entries of the wrapped node's child list (with their double quotes) and
    |literals|, which holds those values with the surrounding quotes stripped.
    """
    def __post_init__(self):
        assert self.is_list()

        # Children of any type other than LITERAL are skipped.
        self.literals: List[str] = [
            x[NODE_VALUE].strip('"') for x in self.node_children
            if x[NODE_TYPE] == 'LITERAL'
        ]

    def add_literal(self, value: str):
        """Adds |value| as a new quoted LITERAL child at the front of the list.

        |value| is wrapped in double quotes as-is; no escaping is applied.
        """
        # For lists of deps, gn format will sort entries, but it will not
        # move entries past comment boundaries. Insert at the front by default
        # so that if sorting moves the value, and there is a comment boundary,
        # it will end up before the comment instead of immediately after the
        # comment (which likely does not apply to it).
        self.literals.insert(0, value)
        self.node_children.insert(
            0, {
                'location': _create_location_node(),
                'type': 'LITERAL',
                'value': f'"{value}"',
            })

    def remove_literal(self, value: str):
        """Removes the first child whose value is |value| in double quotes.

        Raises:
          ValueError: If no child holds the quoted value, or if |value| is not
            present in |literals|. Nothing is modified in either case.
        """
        quoted = f'"{value}"'
        children = self.node_children
        for i, node in enumerate(children):
            # Not every node carries a value (the LIST nodes built by create()
            # do not), so use .get() rather than indexing to avoid a KeyError.
            if node.get(NODE_VALUE) == quoted:
                break
        else:
            raise ValueError(f'Did not find child with value {quoted}')
        # Update |literals| before popping: should it raise, the child list is
        # still untouched and the two views stay in sync.
        self.literals.remove(value)
        children.pop(i)

    @staticmethod
    def create() -> StringList:
        """Returns a StringList wrapping a newly built, empty LIST node."""
        return StringList({
            'location': _create_location_node(),
            'begin_token': '[',
            'child': [],
            'end': {
                'location': _create_location_node(),
                'type': 'END',
                'value': ']'
            },
            'type': 'LIST',
        })
283
284
class Target(NodeWrapper):
    """Wraps a target node.

    A target node is any function besides "template" with exactly two children:
      * Child 1: LIST with single string literal child
      * Child 2: BLOCK

    This does not actually find all targets. E.g. ignores those that use an
    expression for a name, or that use "target(type, name)".
    """
    def __init__(self, function_node: dict, name_node: dict):
        super().__init__(function_node)
        # The LITERAL node that holds the quoted target name.
        self.name_node = name_node

    @property
    def name(self) -> str:
        """The target name with surrounding double quotes stripped."""
        quoted_name = self.name_node[NODE_VALUE]
        return quoted_name.strip('"')

    # E.g. "android_library"
    @property
    def type(self) -> str:
        """The value of the function node, i.e. the function being called."""
        return self.node_value

    @property
    def block(self) -> BlockWrapper:
        """The second child of the function node, wrapped as a BlockWrapper."""
        body = self.second_child
        assert isinstance(body, BlockWrapper)
        return body

    def set_name(self, value):
        """Replaces the target name; |value| is stored in double quotes."""
        self.name_node[NODE_VALUE] = f'"{value}"'

    @staticmethod
    def from_node(node: dict) -> Optional[Target]:
        """Returns a Target if |node| is a target, None otherwise."""
        is_function = node.get(NODE_TYPE) == 'FUNCTION'
        if not is_function or node.get(NODE_VALUE) == 'template':
            return None

        children = node.get(NODE_CHILD)
        if not children or len(children) != 2:
            return None

        params, body = children
        if (body.get(NODE_TYPE) != 'BLOCK'
                or params.get(NODE_TYPE) != 'LIST'):
            return None

        # Exactly one parameter is expected: the literal naming the target.
        param_nodes = params.get(NODE_CHILD)
        if param_nodes is None or len(param_nodes) != 1:
            return None

        candidate = param_nodes[0]
        if candidate.get(NODE_TYPE) != 'LITERAL':
            return None
        return Target(function_node=node, name_node=candidate)
339
340
class BuildFile:
    """Represents the contents of a BUILD.gn file."""
    def __init__(self, path: str, root_node: dict):
        self.block = BlockWrapper(root_node)
        self.path = path
        # Serialized snapshot of the tree as loaded; write_changes() compares
        # against it to detect edits.
        self._original_content = json.dumps(root_node)

    def write_changes(self) -> bool:
        """Returns whether there were any changes.

        When the serialized tree differs from the snapshot taken at
        construction, it is piped to "gn format --read-tree=json" for
        |self.path|.

        Raises:
          Exception: If gn's output lacks the expected confirmation message.
          subprocess.CalledProcessError: If gn exits with a non-zero status.
        """
        serialized = json.dumps(self.block.node)
        if serialized == self._original_content:
            return False
        command = ['gn', 'format', '--read-tree=json', self.path]
        stdout = subprocess.check_output(command,
                                         text=True,
                                         input=serialized)
        if 'Wrote rebuilt from json to' in stdout:
            return True
        raise Exception('JSON was invalid')

    @functools.cached_property
    def targets(self) -> List[Target]:
        """All targets recognized by Target.from_node, computed once."""
        return self.block.visit_nodes(Target.from_node)

    @functools.cached_property
    def targets_by_name(self) -> Dict[str, Target]:
        """Maps target name to Target; a later duplicate name wins."""
        by_name = {}
        for target in self.targets:
            by_name[target.name] = target
        return by_name

    @staticmethod
    def from_file(path):
        """Loads |path| by parsing "gn format --dump-tree=json" output."""
        command = ['gn', 'format', '--dump-tree=json', path]
        tree_json = subprocess.check_output(command, text=True)
        return BuildFile(path, json.loads(tree_json))
374