# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Find specific type ast node in specific scope."""

from typing import Any, List, Optional, Set, Tuple, Type, Union
import ast
import copy
import sys
if sys.version_info >= (3, 9):
    import ast as astunparse  # pylint: disable=reimported, ungrouped-imports
else:
    import astunparse


class AstFinder(ast.NodeVisitor):
    """
    Find all specific type ast node in specific scope.

    Args:
        node (ast.AST): An instance of ast node as search scope.
    """

    def __init__(self, node: ast.AST):
        self._scope: ast.AST = node
        self._targets: tuple = ()
        self._results: List[ast.AST] = []

    def generic_visit(self, node):
        """
        An override method, iterating over all nodes and save target ast nodes.

        Args:
            node (ast.AST): An instance of ast node which is visited currently.
        """
        if isinstance(node, self._targets):
            self._results.append(node)
        super().generic_visit(node)

    def find_all(self, ast_types: Union[Type[ast.AST], Tuple[Type[ast.AST], ...]]) -> List[ast.AST]:
        """
        Find all matched ast node.

        Nodes are collected in pre-order, starting with the scope node itself.

        Args:
            ast_types (Union[tuple(Type), Type]): A tuple of Type or a Type indicates target ast node type.

        Returns:
            A new list of instance of ast.AST as matched result. Later calls do not modify a list returned earlier.

        Raises:
            ValueError: If input `ast_types` is not a type nor a tuple.
        """
        if isinstance(ast_types, type):
            self._targets = (ast_types,)
        elif isinstance(ast_types, tuple):
            self._targets = ast_types
        else:
            raise ValueError("Input ast_types should be a tuple or a type")

        # Use a fresh list per call: clearing and reusing one list would silently
        # wipe results that a caller still holds from a previous call.
        self._results = []
        self.visit(self._scope)
        return self._results


class StrChecker(ast.NodeVisitor):
    """
    Check if specific String exists in specific scope.

    A hit is an `ast.Name` whose id equals the pattern, or an `ast.Attribute` whose value is such a name
    (e.g. ``pattern.xxx``). Only descendants of the scope node are tested, not the scope node itself.

    Args:
        node (ast.AST): An instance of ast node as check scope.
    """

    def __init__(self, node: ast.AST):
        self._context = node
        self._pattern = ""
        self._hit = False

    def visit_Attribute(self, node: ast.Attribute) -> Any:  # pylint: disable=invalid-name
        """Visit a node of type ast.Attribute."""
        if isinstance(node.value, ast.Name) and node.value.id == self._pattern:
            self._hit = True
        return super().generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:  # pylint: disable=invalid-name
        """Visit a node of type ast.Name."""
        if node.id == self._pattern:
            self._hit = True
        return super().generic_visit(node)

    def generic_visit(self, node: ast.AST) -> Any:
        """Visit the children of `node`, stopping as soon as a hit has been recorded."""
        for _, value in ast.iter_fields(node):
            if self._hit:
                break
            if isinstance(value, (list, tuple)):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
                        if self._hit:
                            break
            elif isinstance(value, dict):
                for item in value.values():
                    if isinstance(item, ast.AST):
                        self.visit(item)
                        if self._hit:
                            break
            elif isinstance(value, ast.AST):
                self.visit(value)

    def check(self, pattern: str) -> bool:
        """
        Check if `pattern` exists in `_context`.

        Args:
            pattern (str): A string indicates target pattern.

        Returns:
            A bool indicates if `pattern` exists in `_context`.
        """
        self._pattern = pattern
        self._hit = False
        self.generic_visit(self._context)
        return self._hit


class FindConstValueInInit(ast.NodeVisitor):
    """
    Check if a constant whose value equals a specific pattern exists in specific scope.

    Args:
        node (ast.AST): An instance of ast node as check scope.
    """
    def __init__(self, node: ast.AST):
        self._context = node
        self._pattern = ""
        self._hit = False

    def visit_Constant(self, node: ast.Constant):  # pylint: disable=invalid-name
        """Record a hit when the constant's value equals the pattern."""
        if node.value == self._pattern:
            self._hit = True
        return node

    def visit_Str(self, node):  # pylint: disable=invalid-name
        """
        Record a hit when a string literal equals the pattern.

        Python < 3.8 parses string literals as ``ast.Str`` rather than ``ast.Constant``; this handler is never
        dispatched on newer versions, where ``visit_Constant`` handles them.
        """
        if node.s == self._pattern:
            self._hit = True
        return node

    def check(self, pattern: str) -> bool:
        """
        Check if `pattern` exists in `_context`.

        Args:
            pattern (str): A string indicates target pattern.

        Returns:
            A bool indicates if `pattern` exists in `_context`.
        """
        self._pattern = pattern
        self._hit = False
        self.generic_visit(self._context)
        return self._hit


class CheckPropertyIsUsed(ast.NodeVisitor):
    """
    Check whether a property is used, i.e. whether ``value.attr`` appears in the scope.

    Args:
        node (ast.AST): An instance of ast node.
    """
    def __init__(self, node: ast.AST):
        self._context = node
        self._value = ""
        self._attr = ""
        self._hit = False

    def visit_Attribute(self, node: ast.Attribute) -> Any:  # pylint: disable=invalid-name
        """Visit a node of type ast.Attribute."""
        if isinstance(node.value, ast.Name) and node.value.id == self._value and node.attr == self._attr:
            self._hit = True
        return super().generic_visit(node)

    def generic_visit(self, node: ast.AST) -> Any:
        """
        An override method, iterating over all nodes; stops descending once a hit has been recorded.
        """
        if self._hit:
            return
        super().generic_visit(node)

    def check(self, value, attr) -> bool:
        """
        Check whether `value.attr` exists.

        Args:
            value (str): Name of the object, e.g. ``"self"``.
            attr (str): Name of the attribute.

        Returns:
            A bool indicates if `value.attr` is found.
        """
        self._value = value
        self._attr = attr
        self._hit = False
        self.generic_visit(self._context)
        return self._hit


class GetPropertyOfObj(ast.NodeVisitor):
    """
    Collect names of attributes assigned on ``self`` (e.g. ``self.xxx = ...``) in a scope.

    Args:
        node (ast.AST): An instance of ast node.
    """
    def __init__(self, node: ast.AST):
        self._context = node
        self._property: Set[str] = set()

    def _collect_target(self, target: ast.AST):
        """Record `target` if it is ``self.<attr>``, descending into tuple/list/starred unpacking targets."""
        if isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                self._collect_target(element)
        elif isinstance(target, ast.Starred):
            self._collect_target(target.value)
        elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) and target.value.id == "self":
            self._property.add(target.attr)

    def visit_Assign(self, node: ast.Assign) -> Any:  # pylint: disable=invalid-name
        """Visit a node of type ast.Assign and record every ``self.<attr>`` it assigns."""
        # Inspect every target so chained assignments (``self.a = self.b = 0``) record all names.
        for target in node.targets:
            self._collect_target(target)
        return super().generic_visit(node)

    def get(self) -> Set[str]:
        """
        Get names of attributes assigned on ``self`` in the scope.

        Returns:
            A set of attribute names.
        """
        self._property = set()
        self.generic_visit(self._context)
        return self._property


class AstAssignFinder(ast.NodeVisitor):
    """
    Get assign definition ast of specified parameter in specific scope.

    Args:
        node (ast.AST): An instance of ast node as check scope.
    """
    def __init__(self, node: ast.AST):
        self._context = node
        self._scope = ""
        self._value = ""
        self._target: Optional[ast.Assign] = None

    def visit_Assign(self, node: ast.Assign):  # pylint: disable=invalid-name
        """Record `node` if its first target is ``scope.value`` (or plain ``value`` when no scope is set)."""
        target = node.targets[0]
        if self._scope and isinstance(target, ast.Attribute):
            if target.attr == self._value and isinstance(target.value, ast.Name) \
                    and target.value.id == self._scope:
                self._target = node
        elif not self._scope and isinstance(target, ast.Name):
            if target.id == self._value:
                self._target = node

    def get_ast(self, value: str, scope: str = "") -> Optional[ast.Assign]:
        """
        Get assign ast of specified parameter in specific ast.

        Args:
            value (str): A string indicates assign target value.
            scope (str): A string indicates assign target scope.

        Returns:
            The last (in traversal order) assign ast whose first target is `scope.value` (or `value` when `scope`
            is empty), or None if there is no such assignment.
        """
        self._scope = scope
        self._value = value
        # Reset so a miss returns None instead of the match left over from a previous call.
        self._target = None
        self.generic_visit(self._context)
        return self._target


class AstClassFinder(ast.NodeVisitor):
    """
    Find all specific name of ast class node in specific scope.

    Args:
        node (ast.AST): An instance of ast node as search scope.
    """

    def __init__(self, node: ast.AST):
        self._scope: ast.AST = node
        self._target: str = ""
        self._results: List[ast.ClassDef] = []

    def visit_ClassDef(self, node):  # pylint: disable=invalid-name
        """
        An override method, iterating over all ClassDef nodes and save target ast nodes.

        The class body is not searched further, so classes nested inside another class are not found.

        Args:
            node (ast.AST): An instance of ast node which is visited currently.
        """
        if node.name == self._target:
            self._results.append(node)

    def find_all(self, class_name: str) -> List[ast.ClassDef]:
        """
        Find all matched ast node.

        Args:
            class_name (str): Name of class to be found.

        Returns:
            A new list of instance of ast.ClassDef as matched result. Later calls do not modify it.

        Raises:
            TypeError: If input `class_name` is not str.
        """
        if not isinstance(class_name, str):
            raise TypeError("Input class_name should be a str")
        self._target = class_name
        # Fresh list per call so earlier results held by callers are not wiped.
        self._results = []
        self.visit(self._scope)
        return self._results


class AstFunctionFinder(ast.NodeVisitor):
    """
    Find all specific name of ast function node in specific scope.

    Args:
        node (ast.AST): An instance of ast node as search scope.
    """

    def __init__(self, node: ast.AST):
        self._scope: ast.AST = node
        self._target: str = ""
        self._results: List[ast.FunctionDef] = []

    def visit_FunctionDef(self, node):  # pylint: disable=invalid-name
        """
        An override method, iterating over all FunctionDef nodes and save target ast nodes.

        The function body is not searched further, so nested functions are not found. Only ``ast.FunctionDef``
        is matched (not ``ast.AsyncFunctionDef``).

        Args:
            node (ast.AST): An instance of ast node which is visited currently.
        """
        if node.name == self._target:
            self._results.append(node)

    def find_all(self, func_name: str) -> List[ast.FunctionDef]:
        """
        Find all matched ast node.

        Args:
            func_name (str): Name of function to be found.

        Returns:
            A new list of instance of ast.FunctionDef as matched result. Later calls do not modify it.

        Raises:
            TypeError: If input `func_name` is not str.
        """
        if not isinstance(func_name, str):
            raise TypeError("Input func_name should be a str")
        self._target = func_name
        # Fresh list per call so earlier results held by callers are not wiped.
        self._results = []
        self.visit(self._scope)
        return self._results


class AstImportFinder(ast.NodeVisitor):
    """Find all import nodes from input ast node."""
    def __init__(self, ast_root: ast.AST):
        self.ast_root = ast_root
        # Deep copies of ast.Import / ast.ImportFrom nodes found outside the saved try blocks.
        self.import_nodes = []
        # Deep copies of try blocks whose first statement is an import.
        self.try_nodes = []
        # Unparsed source of each entry in `import_nodes`; kept index-aligned with it.
        self.imports_str = []

    def visit_Try(self, node: ast.Try) -> Any:  # pylint: disable=invalid-name
        """
        Save a copy of a try block whose first statement is an import.

        Children of any try block are not visited, so imports nested in a try block whose first statement is
        not an import are not collected.
        """
        if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
            self.try_nodes.append(copy.deepcopy(node))
        return node

    def visit_Import(self, node: ast.Import) -> Any:  # pylint: disable=invalid-name
        """Iterate over all nodes and save ast.Import nodes."""
        self.import_nodes.append(copy.deepcopy(node))
        self.imports_str.append(astunparse.unparse(node))
        return node

    def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:  # pylint: disable=invalid-name
        """Iterate over all nodes and save ast.ImportFrom nodes."""
        self.import_nodes.append(copy.deepcopy(node))
        self.imports_str.append(astunparse.unparse(node))
        return node

    def _remove_duplicated_import_in_try(self, node: ast.AST):
        """Drop the first saved import whose source text equals `node`'s, since the try block carries it."""
        import_str = astunparse.unparse(node)
        if import_str in self.imports_str:
            index = self.imports_str.index(import_str)
            # Delete from both lists so they stay index-aligned for subsequent lookups.
            del self.imports_str[index]
            del self.import_nodes[index]

    def get_import_node(self) -> List[ast.stmt]:
        """
        Collect import statements under `ast_root`.

        Returns:
            A list of deep-copied ast.Import / ast.ImportFrom nodes followed by deep-copied ast.Try nodes whose
            first statement is an import. For each statement in those try blocks (body and handlers), the first
            remaining import with identical source text is dropped from the first part, so it is not emitted twice.
        """
        # Reset so repeated calls do not accumulate duplicates.
        self.import_nodes = []
        self.try_nodes = []
        self.imports_str = []
        self.generic_visit(self.ast_root)
        for try_node in self.try_nodes:
            for stmt in try_node.body:
                self._remove_duplicated_import_in_try(stmt)
            for handler in try_node.handlers:
                for stmt in handler.body:
                    self._remove_duplicated_import_in_try(stmt)
        self.import_nodes.extend(self.try_nodes)
        return self.import_nodes