• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Find specific type ast node in specific scope."""
16
17from typing import Type, Any
18import ast
19import copy
20import sys
21if sys.version_info >= (3, 9):
22    import ast as astunparse # pylint: disable=reimported, ungrouped-imports
23else:
24    import astunparse
25
26
class AstFinder(ast.NodeVisitor):
    """
    Find all ast nodes of the requested type(s) inside a given scope.

    Args:
        node (ast.AST): An instance of ast node as search scope.
    """

    def __init__(self, node: ast.AST):
        self._scope: ast.AST = node
        self._targets: tuple = ()
        self._results: list = []

    def generic_visit(self, node):
        """
        An override method, iterating over all nodes and saving target ast nodes.

        Args:
            node (ast.AST): An instance of ast node which is visited currently.
        """

        if isinstance(node, self._targets):
            self._results.append(node)
        super().generic_visit(node)

    def find_all(self, ast_types) -> list:
        """
        Find all matched ast nodes, the scope node itself included.

        Args:
            ast_types (Union[tuple(Type), Type]): A tuple of Type or a Type indicates target ast node type.

        Returns:
            A new list of instances of ast.AST as matched result, in visiting order. The list is not reused by
            later calls, so results of successive searches can be kept side by side.

        Raises:
            ValueError: If input `ast_types` is not a type nor a tuple.
        """

        if isinstance(ast_types, Type):
            targets = (ast_types,)
        elif isinstance(ast_types, tuple):
            targets = ast_types
        else:
            raise ValueError("Input ast_types should be a tuple or a type")
        self._targets = targets

        # Bind a fresh list rather than clearing the old one: the previous list
        # was handed out to the caller of the previous search and must stay intact.
        self._results = []
        self.visit(self._scope)
        return self._results
76
77
class StrChecker(ast.NodeVisitor):
    """
    Tell whether an identifier is referenced by any node below a scope node.

    Args:
        node (ast.AST): An instance of ast node as check scope.
    """

    def __init__(self, node: ast.AST):
        self._hit = False
        self._pattern = ""
        self._context = node

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Record a hit when the attribute is read from a name equal to the pattern."""
        owner = node.value
        if isinstance(owner, ast.Name) and owner.id == self._pattern:
            self._hit = True
        # Children are walked with the base-class traversal, not the override below.
        return super().generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        """Record a hit when the name itself equals the pattern."""
        if self._pattern == node.id:
            self._hit = True
        return super().generic_visit(node)

    def generic_visit(self, node: ast.AST) -> Any:
        """Walk the child nodes of `node`, stopping as soon as a hit has been recorded."""
        for _, field in ast.iter_fields(node):
            if self._hit:
                return
            if isinstance(field, ast.AST):
                self.visit(field)
                continue
            if isinstance(field, (list, tuple)):
                children = field
            elif isinstance(field, dict):
                children = field.values()
            else:
                # Plain values (str, int, None, ...) hold no child nodes.
                continue
            for child in children:
                if isinstance(child, ast.AST):
                    self.visit(child)
                if self._hit:
                    break

    def check(self, pattern: str) -> bool:
        """
        Check if `pattern` exists in `_context`.

        Only the children of `_context` are inspected; the scope node itself is not tested.

        Args:
            pattern (str): A string indicates target pattern.

        Returns:
            A bool indicates if `pattern` exists in `_context`.
        """
        self._hit = False
        self._pattern = pattern
        self.generic_visit(self._context)
        return self._hit
136
137
class FindConstValueInInit(ast.NodeVisitor):
    """
    Check if a constant equal to a given value exists below a scope node.

    Args:
        node (ast.AST): An instance of ast node as check scope.
    """
    def __init__(self, node: ast.AST):
        self._hit = False
        self._pattern = ""
        self._context = node

    def visit_Constant(self, node: ast.Constant):
        """Record a hit when the constant's value equals the pattern."""
        matched = node.value == self._pattern
        if matched:
            self._hit = True
        return node

    def check(self, pattern: str) -> bool:
        """
        Check if a constant equal to `pattern` exists in `_context`.

        Args:
            pattern (str): The value to compare every constant against.

        Returns:
            A bool indicates if `pattern` exists in `_context`.
        """
        self._hit = False
        self._pattern = pattern
        self.generic_visit(self._context)
        return self._hit
169
170
class CheckPropertyIsUsed(ast.NodeVisitor):
    """
    Check whether an attribute access of the form `value.attr` appears below a node.

    Args:
        node (ast.AST): An instance of ast node.
    """
    def __init__(self, node: ast.AST):
        self._hit = False
        self._attr = ""
        self._value = ""
        self._context = node

    def visit_Attribute(self, node: ast.Attribute) -> Any:  # pylint: disable=invalid-name
        """Record a hit when the node reads `_attr` from a name equal to `_value`."""
        owner = node.value
        same_owner = isinstance(owner, ast.Name) and owner.id == self._value
        if same_owner and node.attr == self._attr:
            self._hit = True
        # Children are walked with the base-class traversal.
        return super().generic_visit(node)

    def generic_visit(self, node: ast.AST) -> Any:
        """
        An override method, iterating over child nodes until a hit has been recorded.
        """
        if not self._hit:
            super().generic_visit(node)

    def check(self, value, attr) -> bool:
        """
        Check whether the attribute access `value.attr` exists in the children of `_context`.

        Args:
            value (str): Name of the object the attribute is read from.
            attr (str): Name of the attribute.

        Returns:
            A bool indicates if `value.attr` is found.
        """
        self._hit = False
        self._value, self._attr = value, attr
        self.generic_visit(self._context)
        return self._hit
207
208
class GetPropertyOfObj(ast.NodeVisitor):
    """
    Collect the names of all attributes assigned on `self` below a node.

    Args:
        node (ast.AST): An instance of ast node.
    """
    def __init__(self, node: ast.AST):
        self._context = node
        self._property = set()

    def visit_Assign(self, node: ast.Assign) -> Any:  # pylint: disable=invalid-name
        """Visit a node of type ast.Assign and record every `self.<attr>` target."""
        # An assignment can carry several targets (`a = self.b = value`), so all
        # of them are inspected rather than only the first one.
        for target in node.targets:
            if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
                    and target.value.id == "self":
                self._property.add(target.attr)
        return super().generic_visit(node)

    def get(self):
        """
        Get the names of attributes assigned on `self` in the children of `_context`.

        Returns:
            A set of str, one entry per attribute name found. A new set is built on every call.
        """
        self._property = set()
        self.generic_visit(self._context)
        return self._property
234
235
class AstAssignFinder(ast.NodeVisitor):
    """
    Get the assign ast node defining a specific parameter in a specific scope.

    Args:
        node (ast.AST): An instance of ast node as check scope.
    """
    def __init__(self, node: ast.AST):
        self._context = node
        self._scope = ""
        self._value = ""
        self._target = None

    def visit_Assign(self, node: ast.Assign):
        """Remember `node` when one of its targets is `scope.value` (or plain `value` without scope)."""
        # Every target is checked so that chained assignments such as
        # `a = self.b = value` are found as well.
        for target in node.targets:
            if self._scope:
                if isinstance(target, ast.Attribute) and target.attr == self._value \
                        and isinstance(target.value, ast.Name) and target.value.id == self._scope:
                    self._target = node
            elif isinstance(target, ast.Name) and target.id == self._value:
                self._target = node

    def get_ast(self, value: str, scope: str = "") -> Any:
        """
        Get assign ast of a specific parameter in the children of `_context`.

        Args:
            value (str): A string indicates assign target value.
            scope (str): A string indicates assign target scope. Default: "", which means the
                target is a plain name instead of an attribute.

        Returns:
            An instance of ast.Assign with the same target name as `scope.value`, the last one
            visited if there are several. None if no such assignment exists.
        """
        self._scope = scope
        self._value = value
        # Forget the match of any previous search, otherwise a failed lookup
        # would hand back the node found by an earlier call.
        self._target = None
        self.generic_visit(self._context)
        return self._target
273
274
class AstClassFinder(ast.NodeVisitor):
    """
    Find all ast class nodes with a specific name in a specific scope.

    Args:
        node (ast.AST): An instance of ast node as search scope.
    """

    def __init__(self, node: ast.AST):
        self._scope: ast.AST = node
        self._target: str = ""
        self._results: list = []

    def visit_ClassDef(self, node):
        """
        An override method, saving the visited ClassDef node when its name matches.

        The body of a class is not traversed, so classes nested inside another class are not searched.

        Args:
            node (ast.AST): An instance of ast node which is visited currently.
        """

        if node.name == self._target:
            self._results.append(node)

    def find_all(self, class_name: str) -> list:
        """
        Find all matched ast node.

        Args:
            class_name (str): Name of class to be found.

        Returns:
            A new list of instance of ast.ClassDef as matched result. The list is not reused by
            later calls.

        Raises:
            TypeError: If input `class_name` is not str.
        """
        if not isinstance(class_name, str):
            raise TypeError("Input class_name should be a str")
        self._target = class_name
        # Bind a fresh list rather than clearing the old one, which is still
        # owned by the caller of the previous search.
        self._results = []
        self.visit(self._scope)
        return self._results
318
319
class AstFunctionFinder(ast.NodeVisitor):
    """
    Find all ast function nodes with a specific name in a specific scope.

    Args:
        node (ast.AST): An instance of ast node as search scope.
    """

    def __init__(self, node: ast.AST):
        self._scope: ast.AST = node
        self._target: str = ""
        self._results: list = []

    def visit_FunctionDef(self, node):
        """
        An override method, saving the visited FunctionDef node when its name matches.

        The body of a function is not traversed, so functions nested inside another function are not searched.

        Args:
            node (ast.AST): An instance of ast node which is visited currently.
        """

        if node.name == self._target:
            self._results.append(node)

    def find_all(self, func_name: str) -> list:
        """
        Find all matched ast node.

        Args:
            func_name (str): Name of function to be found.

        Returns:
            A new list of instance of ast.FunctionDef as matched result. The list is not reused by
            later calls.

        Raises:
            TypeError: If input `func_name` is not str.
        """
        if not isinstance(func_name, str):
            raise TypeError("Input func_name should be a str")
        self._target = func_name
        # Bind a fresh list rather than clearing the old one, which is still
        # owned by the caller of the previous search.
        self._results = []
        self.visit(self._scope)
        return self._results
363
364
class AstImportFinder(ast.NodeVisitor):
    """
    Find all import nodes from input ast node.

    Args:
        ast_root (ast.AST): An instance of ast node whose children are searched.
    """
    def __init__(self, ast_root: ast.AST):
        self.ast_root = ast_root
        # `import_nodes` and `imports_str` are parallel lists: the string at an
        # index is the unparsed source of the node at the same index.
        self.import_nodes = []
        self.try_nodes = []
        self.imports_str = []

    def visit_Try(self, node: ast.Try) -> Any:
        """Save ast.Try nodes whose first body statement is an import; the try body is not traversed."""
        if node.body and isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
            self.try_nodes.append(copy.deepcopy(node))
        return node

    def visit_Import(self, node: ast.Import) -> Any:
        """Iterate over all nodes and save ast.Import nodes."""
        self.import_nodes.append(copy.deepcopy(node))
        self.imports_str.append(astunparse.unparse(node))
        return node

    def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
        """Iterate over all nodes and save ast.ImportFrom nodes."""
        self.import_nodes.append(copy.deepcopy(node))
        self.imports_str.append(astunparse.unparse(node))
        return node

    def _remove_duplicated_import_in_try(self, node: [ast.Import, ast.ImportFrom]):
        """Drop one collected import that has the same source text as `node`, a statement of a try block."""
        if not isinstance(node, (ast.Import, ast.ImportFrom)):
            # Only import strings are recorded, other statements can never match.
            return
        import_str = astunparse.unparse(node)
        try:
            index = self.imports_str.index(import_str)
        except ValueError:
            return
        # Remove the entry from both lists so that their indices stay aligned
        # for the following removals.
        del self.import_nodes[index]
        del self.imports_str[index]

    def get_import_node(self):
        """
        Collect the import statements found in the children of `ast_root`.

        Imports that are repeated inside a saved try block (its body or its exception handlers)
        are dropped from the plain import list, and the try blocks themselves are appended instead.

        Returns:
            A list of copies of ast.Import and ast.ImportFrom nodes, followed by copies of the
            saved ast.Try nodes.
        """
        # Start from a clean state so that calling this method again does not
        # accumulate the nodes of the previous run.
        self.import_nodes = []
        self.try_nodes = []
        self.imports_str = []
        self.generic_visit(self.ast_root)
        for try_node in self.try_nodes:
            for body in try_node.body:
                self._remove_duplicated_import_in_try(body)
            for handler in try_node.handlers:
                for body in handler.body:
                    self._remove_duplicated_import_in_try(body)
        self.import_nodes.extend(self.try_nodes)
        return self.import_nodes
405