# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Ast utils for create or update ast node."""
from typing import Optional, List, Union, Dict
import sys
import ast
from ..api.scoped_value import ScopedValue, ValueType

if sys.version_info >= (3, 9):
    import ast as astunparse  # pylint: disable=reimported, ungrouped-imports
else:
    import astunparse


class AstModifier(ast.NodeTransformer):
    """Ast utils for create or update ast node."""

    @staticmethod
    def insert_ast_to_ast(ast_container: Union[ast.AST, list], ast_node: ast.AST,
                          index_ast: Optional[ast.AST] = None, insert_before=True):
        """
        Insert ast node into an ast container.

        Only a list of statements (an ast body) and ast.FunctionDef are supported as containers.

        Args:
            ast_container (Union[ast.AST, list]): A list of ast statements or an ast.FunctionDef.
            ast_node (ast.AST): Node to be inserted.
            index_ast ([ast.AST, optional]): Anchor node. None means append.
            insert_before (bool): Insert before (True) or after (False) 'index_ast'.

        Returns:
            The inserted ast node.

        Raises:
            NotImplementedError: If 'ast_container' is neither a list nor an ast.FunctionDef.
            ValueError: If 'index_ast' is given but not found in the container.
        """
        if isinstance(ast_container, list):
            return AstModifier.insert_ast_to_bodies(ast_container, ast_node, index_ast, insert_before)
        if isinstance(ast_container, ast.FunctionDef):
            return AstModifier.insert_ast_to_function(ast_container, ast_node, index_ast, insert_before)
        raise NotImplementedError(f"Insert ast node into {type(ast_container)} is not support yet.")

    @staticmethod
    def earse_ast_of_control_flow(ast_root_body: list, ast_branch: ast.AST, is_orelse: bool) -> bool:
        """
        Clear one branch of a control-flow node, removing the node entirely once nothing is left.

        If 'is_orelse' is True the ``orelse`` list of 'ast_branch' is emptied, otherwise its ``body`` is replaced
        by a single ``pass``. When afterwards the body is just ``pass`` and there is no ``orelse``, 'ast_branch'
        itself is erased (by identity) from 'ast_root_body'.

        Note: the misspelled method name is kept for backward compatibility with existing callers.

        Args:
            ast_root_body (list): Body list that contains 'ast_branch'.
            ast_branch (ast.AST): Control-flow node owning ``body`` and ``orelse`` lists.
            is_orelse (bool): Clear the ``orelse`` branch (True) or the ``body`` branch (False).

        Returns:
            Always True.
        """
        if is_orelse:
            ast_branch.orelse = []
        else:
            ast_branch.body = [ast.Pass()]
        if len(ast_branch.body) == 1 and isinstance(ast_branch.body[0], ast.Pass) and not ast_branch.orelse:
            AstModifier.erase_ast_from_bodies(ast_root_body, ast_branch)
        return True

    @staticmethod
    def erase_ast_from_function(ast_func: ast.FunctionDef, to_erase: ast.AST) -> bool:
        """
        Erase ast node from ast.FunctionDef.

        Args:
            ast_func (ast.FunctionDef): From which to search to_erase-node and erase.
            to_erase (ast.AST): Node to be erased.

        Returns:
            A bool if to_erase-node been found and been erased.
        """
        return AstModifier.erase_ast_from_bodies(ast_func.body, to_erase)

    @staticmethod
    def erase_ast_from_bodies(ast_bodies: List[ast.AST], to_erase: ast.AST) -> bool:
        """
        Erase the first element of 'ast_bodies' that is the very same object as 'to_erase'.

        Args:
            ast_bodies (List[ast.AST]): List to erase from (modified in place).
            to_erase (ast.AST): Node to be erased, matched by identity.

        Returns:
            True if a node was erased, False if 'to_erase' is not in 'ast_bodies'.
        """
        for index, body in enumerate(ast_bodies):
            if body is to_erase:
                del ast_bodies[index]
                return True
        return False

    @staticmethod
    def erase_func_from_class_by_name(ast_class: ast.ClassDef, func_name: str):
        """
        Erase ast FunctionDef from ast.ClassDef by name.

        Every ast.FunctionDef named 'func_name' directly in the class body is removed.

        Args:
            ast_class (ast.ClassDef): From which to search to_erase-node and erase.
            func_name (str): Function name to be erased.
        """
        # Rebuild the body in place (same list object). Removing while iterating would skip the element that
        # follows each removed one, leaving consecutive same-named definitions behind.
        ast_class.body[:] = [body for body in ast_class.body
                             if not (isinstance(body, ast.FunctionDef) and body.name == func_name)]

    @staticmethod
    def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None,
                       insert_before=True) -> ast.AST:
        """
        Insert an ast node into another ast node's body.

        Args:
            ast_father (ast.AST): Where new ast node to be inserted into.
            ast_son (ast.AST): An ast node to be inserted in.
            index_ast ([ast.AST, optional]): An ast_node indicates a position in 'ast_father' where new ast node to be
                inserted into. Default is None which means append new ast node to body of 'ast_father'.
            insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast node to be
                inserted into. Only valid when 'index_ast' is not None. Default is True which means inserting new ast
                node before 'index_ast'.

        Returns:
            An instance of ast.AST which has been inserted into 'ast_father'.

        Raises:
            ValueError: If 'ast_father' has no attribute named 'body'.
            RuntimeError: If 'index_ast' is not contained in 'ast_father'.
        """
        if not hasattr(ast_father, "body"):
            raise ValueError(f"Input ast_father has no attribute body: {type(ast_father)}")
        if index_ast is None:
            ast_father.body.append(ast_son)
            ast.fix_missing_locations(ast_father)
            return ast_son
        for index, body in enumerate(ast_father.body):
            if body is index_ast:
                ast_father.body.insert(index if insert_before else index + 1, ast_son)
                ast.fix_missing_locations(ast_father)
                return ast_son
        raise RuntimeError("index_ast is not contained in ast_father")

    @staticmethod
    def insert_class_into_module(ast_mod: ast.Module, ast_class: ast.ClassDef, index_ast: Optional[ast.AST] = None,
                                 insert_before=True) -> ast.ClassDef:
        """
        Insert an ast.ClassDef into an ast.Module.

        Args:
            ast_mod (ast.Module): Where new ast.ClassDef to be inserted into.
            ast_class (ast.ClassDef): ClassDef to be inserted.
            index_ast ([ast.AST, optional]): An ast_node indicates a position in 'ast_mod' where new ast.ClassDef node
                to be inserted into. Default is None which means append new ast.ClassDef into 'ast_mod'.
            insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.ClassDef node to
                be inserted into. Only valid when 'index_ast' is not None. Default is True which means inserting new
                ast.ClassDef before 'index_ast'.

        Returns:
            An instance of ast.ClassDef which has been inserted into 'ast_mod'.

        Raises:
            RuntimeError: If 'index_ast' is not contained in 'ast_mod'.
        """
        return AstModifier.insert_sub_ast(ast_mod, ast_class, index_ast, insert_before)

    @staticmethod
    def insert_assign_to_function(ast_func: ast.FunctionDef, targets: List[ScopedValue], expr: ScopedValue,
                                  args: Optional[List[ScopedValue]] = None,
                                  kwargs: Optional[Dict[str, ScopedValue]] = None,
                                  index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
        """
        Insert an ast.Assign into an ast.FunctionDef.

        Args:
            ast_func (ast.FunctionDef): Where new ast.Assign to be inserted into.
            targets (List[ScopedValue]): Targets of ast.Assign.
            expr (ScopedValue): Func of ast.Call which is value of new ast.Assign.
            args (List[ScopedValue], optional): Args of ast.Call which is value of new ast.Assign.
            kwargs (Dict[str, ScopedValue], optional): Kwargs of ast.Call which is value of new ast.Assign.
            index_ast ([ast.AST, optional]): An ast_node indicates a position in 'ast_func' where new ast.Assign node to
                be inserted into. Default is None which means append new ast.Assign into 'ast_func'.
            insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.Assign node to be
                inserted into. Only valid when 'index_ast' is not None. Default is True which means inserting new
                ast.Assign before 'index_ast'.

        Returns:
            An instance of ast.Assign which has been inserted into 'ast_func'.

        Raises:
            ValueError: If 'index_ast' is not contained in 'ast_func', or 'targets' is empty.
            RuntimeError, TypeError: If 'targets', 'expr', 'args' or 'kwargs' are invalid (see create_call_assign).
        """
        assign = AstModifier.create_call_assign(targets, expr, args, kwargs)
        return AstModifier.insert_ast_to_function(ast_func, assign, index_ast, insert_before)

    @staticmethod
    def insert_ast_to_function(ast_func: ast.FunctionDef, ast_node: ast.AST,
                               index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
        """
        Insert an ast into an ast.FunctionDef.

        When 'index_ast' is one of the function's own arguments, 'ast_node' is inserted at the very front of
        the function body. This covers positional-only, positional, keyword-only, ``*args`` and ``**kwargs``
        arguments. Otherwise 'index_ast' is searched for in the function body.

        Args:
            ast_func (ast.FunctionDef): Where new ast to be inserted into.
            ast_node (ast.AST): An instance of ast.AST to be inserted in.
            index_ast ([ast.AST, optional]): An ast_node indicates a position in 'ast_func' where new ast node to
                be inserted into. Default is None which means append new ast to 'ast_func'.
            insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast node to be
                inserted into. Only valid when 'index_ast' is not None. Default is True which means inserting new
                ast before 'index_ast'.

        Returns:
            The ast node which has been inserted into 'ast_func'.

        Raises:
            ValueError: If 'index_ast' is neither an argument of 'ast_func' nor contained in its body.
        """
        arguments: ast.arguments = ast_func.args
        # posonlyargs only exists on Python >= 3.8; vararg/kwarg may be None (never identical to index_ast).
        all_args = [*getattr(arguments, "posonlyargs", []), *arguments.args, *arguments.kwonlyargs,
                    arguments.vararg, arguments.kwarg]
        if index_ast is not None and any(arg is index_ast for arg in all_args):
            ast_func.body.insert(0, ast_node)
            ast.fix_missing_locations(ast_func)
            return ast_node
        # Insert ast at position specified by index_ast in function body
        return AstModifier.insert_ast_to_bodies(ast_func.body, ast_node, index_ast, insert_before)

    @staticmethod
    def insert_ast_to_bodies(ast_bodies: List[ast.AST], ast_node: ast.AST,
                             index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
        """
        Insert 'ast_node' into 'ast_bodies' at the position specified by 'index_ast'.

        Args:
            ast_bodies (List[ast.AST]): Statement list to insert into (modified in place).
            ast_node (ast.AST): Node to be inserted. Missing location info on it is filled in.
            index_ast ([ast.AST, optional]): Anchor node, matched by identity. None means append.
            insert_before (bool): Insert before (True) or after (False) 'index_ast'.

        Returns:
            The inserted 'ast_node'.

        Raises:
            ValueError: If 'index_ast' is not contained in 'ast_bodies'.
        """
        if index_ast is None:
            ast_bodies.append(ast_node)
            ast.fix_missing_locations(ast_node)
            return ast_node
        for index, body in enumerate(ast_bodies):
            if body is index_ast:
                ast_bodies.insert(index if insert_before else index + 1, ast_node)
                # Fill in locations on the newly inserted node; the anchor node already has its own.
                ast.fix_missing_locations(ast_node)
                return ast_node
        raise ValueError(f"insert position ({'before' if insert_before else 'after'} "
                         f"{astunparse.unparse(index_ast).strip()}) is not contained in ast_bodies")

    @staticmethod
    def append_arg_to_function(ast_func: ast.FunctionDef, ast_arg: ast.arg) -> ast.AST:
        """
        Append an ast.arg to an ast.FunctionDef (e.g. self.construct).

        The argument is appended to the positional ``args`` of the function, and a ``None`` default is appended
        to ``defaults`` so that it becomes an optional trailing argument.

        Args:
            ast_func (ast.FunctionDef): An instance of ast.FunctionDef which is "construct" function of network.
            ast_arg (ast.arg): An instance of ast.arg to be inserted in.

        Returns:
            An instance of ast.arg which has been appended to 'ast_func'.

        Raises:
            RuntimeError: If 'ast_arg' is not an instance of ast.arg.
        """
        if not isinstance(ast_arg, ast.arg):
            raise RuntimeError("ast_arg should be an instance of ast.arg.")
        arguments: ast.arguments = ast_func.args
        arguments.args.append(ast_arg)
        arguments.defaults.append(ast.Constant(value=None, kind=None))
        return ast_arg

    @staticmethod
    def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: List[ScopedValue],
                                        field: str) -> ast.AST:
        """
        Append an ast.Assign, whose value is a ``getattr(obj, ...)`` call, to an "__init__" ast.FunctionDef.

        The generated statement has the shape ``<targets> = getattr(obj, <field value>)``. The second argument
        of ``getattr`` is built from 'field' by ``ScopedValue.create_variable_value``. The name ``obj`` must be
        resolvable in the generated "__init__" function.

        Args:
            init_func (ast.FunctionDef): An instance of ast.FunctionDef which is "__init__" function of network.
            targets (List[ScopedValue]): Targets of ast.Assign.
            field (str): A string represents name of new custom op field.

        Returns:
            An instance of ast.Assign which has been appended to 'init_func'.
        """
        return AstModifier.insert_assign_to_function(init_func, targets=targets,
                                                     expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
                                                     args=[ScopedValue(ValueType.NamingValue, "obj"),
                                                           ScopedValue.create_variable_value(field)])

    @staticmethod
    def create_call_assign(targets: List[ScopedValue], expr: ScopedValue,
                           args: Optional[List[ScopedValue]] = None,
                           kwargs: Optional[Dict[str, ScopedValue]] = None) -> ast.Assign:
        """
        Create an instance of ast.Assign whose value must be an ast.Call.

        A single target is assigned directly; several targets are packed into one ast.Tuple target.

        Args:
            targets (List[ScopedValue]): Targets of ast.Assign.
            expr (ScopedValue): Func of ast.Call which is value of new ast.Assign.
            args (List[ScopedValue], optional): Args of ast.Call which is value of new ast.Assign.
            kwargs (Dict[str, ScopedValue], optional): Kwargs of ast.Call which is value of new ast.Assign.

        Returns:
            An instance of ast.Assign.

        Raises:
            RuntimeError: If 'targets' is None.
            RuntimeError: If value_type of element of 'targets' is not ValueType.NamingValue.
            ValueError: If 'targets' is empty.
        """
        if targets is None:
            raise RuntimeError("Targets should not be None.")
        targets_list = []
        for target in targets:
            if target.type != ValueType.NamingValue:
                raise RuntimeError(f"Target must be a left-value (ValueType.NamingValue), got: {target}")
            if target.scope:
                ast_target = ast.Attribute(ast.Name(target.scope, ast.Load()), target.value, ast.Store())
            else:
                ast_target = ast.Name(target.value, ast.Store())
            targets_list.append(ast_target)
        call = AstModifier.create_call(expr, args, kwargs)

        if not targets_list:
            raise ValueError(f"For '{astunparse.unparse(call).strip()}', targets should not be empty, but got "
                             f"{targets}, len(targets) is {len(targets_list)}")
        if len(targets_list) == 1:
            assign_target = targets_list[0]
        else:
            assign_target = ast.Tuple(elts=targets_list, ctx=ast.Store())
        result = ast.Assign(targets=[assign_target], value=call)
        ast.fix_missing_locations(result)
        return result

    @staticmethod
    def _create_arg_by_constant_value(value: ScopedValue):
        """
        Create an instance of ast.Constant.

        Args:
            value (ScopedValue): value used to create arg.

        Raises:
            RuntimeError: if scope of value is not empty.
            TypeError: type of arg is not ValueType.ConstantValue

        Returns:
            ast.Constant: An instance of ast.Constant
        """
        if value.type != ValueType.ConstantValue:
            raise TypeError(f"Type of arg only support ValueType.ConstantValue, but got {value.type}")
        if value.scope:
            raise RuntimeError("For arg the scope should be empty")
        return ast.Constant(value=value.value, kind=None)

    @staticmethod
    def _create_list_or_tuple(value: ScopedValue):
        """
        Create an instance of ast.List or ast.Tuple.

        Elements must be constant ScopedValues. An ast.List is produced for ValueType.ListValue, and an
        ast.Tuple otherwise.

        Args:
            value (ScopedValue): value used to create ast node.

        Raises:
            TypeError: If an element is not a ValueType.ConstantValue.
            RuntimeError: If an element has a non-empty scope.

        Returns:
            ast.List or ast.Tuple: An instance of ast.List or ast.Tuple.
        """
        elts = [AstModifier._create_arg_by_constant_value(v) for v in value.value]
        # 'value' is a ScopedValue, so the list/tuple decision must come from its value type.
        if value.type == ValueType.ListValue:
            return ast.List(elts=elts, ctx=ast.Load())
        return ast.Tuple(elts=elts, ctx=ast.Load())

    @staticmethod
    def _create_keyword(arg: str, value: ScopedValue):
        """
        Create an instance of ast.keyword.

        Args:
            arg (str): key of keyword.
            value (ScopedValue): value used to create ast.keyword instance.

        Raises:
            RuntimeError: if scope of value is not empty.
            TypeError: type of value is not one of ValueType.ConstantValue, ValueType.ListValue, ValueType.TupleValue.

        Returns:
            ast.keyword: a instance of ast.keyword.
        """
        if value.scope:
            raise RuntimeError("value.scope should be empty")
        if value.type == ValueType.ConstantValue:
            ast_value = ast.Constant(value=value.value, kind=None)
        elif value.type in (ValueType.ListValue, ValueType.TupleValue):
            ast_value = AstModifier._create_list_or_tuple(value)
        else:
            raise TypeError("Type of keyword value only support [ValueType.ConstantValue, ValueType.ListValue, "
                            f"ValueType.TupleValue], but got {value.type}")
        return ast.keyword(arg=arg, value=ast_value)

    @staticmethod
    def _create_name_or_attribute(value: ScopedValue) -> ast.expr:
        """Create ``scope.value`` (ast.Attribute) or ``value`` (ast.Name) in Load context from a NamingValue."""
        if value.scope:
            return ast.Attribute(ast.Name(value.scope, ast.Load()), value.value, ast.Load())
        return ast.Name(value.value, ast.Load())

    @staticmethod
    def _create_call_args(args: Optional[List[ScopedValue]]) -> List[ast.AST]:
        """
        Create a list of ast.AST as args of ast.Call from a list of `ScopedValue`.

        Args:
            args (list[ScopedValue]): Args of ast.Call. None means no args.

        Returns:
            A list of ast.AST as args of ast.Call.

        Raises:
            TypeError: If element of 'args' is not an instance of `ScopedValue`.
            RuntimeError: If value_type of element of 'args' is not constant, naming, list or tuple.
        """
        if args is None:
            return []
        results = []
        for arg in args:
            if not isinstance(arg, ScopedValue):
                raise TypeError(f"arg should be ScopedValue, got: {type(arg)}")
            if arg.type == ValueType.ConstantValue:
                results.append(AstModifier._create_arg_by_constant_value(arg))
            elif arg.type == ValueType.NamingValue:
                results.append(AstModifier._create_name_or_attribute(arg))
            elif arg.type in (ValueType.ListValue, ValueType.TupleValue):
                results.append(AstModifier._create_list_or_tuple(arg))
            else:
                raise RuntimeError("Please handle custom-object first")
        return results

    @staticmethod
    def _create_call_kwargs(kwargs: Optional[Dict[str, ScopedValue]]) -> List[ast.keyword]:
        """
        Create a list of ast.keyword as kwargs of ast.Call from a dict of string to `ScopedValue`.

        Args:
            kwargs (dict{str: ScopedValue}): Kwargs of ast.Call. None means no kwargs.

        Returns:
            A list of ast.keyword as keywords of ast.Call.

        Raises:
            TypeError: If a value of 'kwargs' is not an instance of `ScopedValue`.
            RuntimeError: If value_type of a value of 'kwargs' is not constant, naming, list or tuple.
        """
        if kwargs is None:
            return []
        results = []
        for arg, value in kwargs.items():
            if not isinstance(value, ScopedValue):
                raise TypeError(f"value should be ScopedValue, got: {type(value)}")
            if value.type in (ValueType.ConstantValue, ValueType.ListValue, ValueType.TupleValue):
                results.append(AstModifier._create_keyword(arg, value))
            elif value.type == ValueType.NamingValue:
                results.append(ast.keyword(arg=arg, value=AstModifier._create_name_or_attribute(value)))
            else:
                raise RuntimeError("Please handle custom-object first")
        return results

    @staticmethod
    def create_call(expr: ScopedValue, args: Optional[List[ScopedValue]] = None,
                    kwargs: Optional[Dict[str, ScopedValue]] = None) -> ast.Call:
        """
        Create an instance of ast.Call.

        Args:
            expr (ScopedValue): Func of ast.Call.
            args (List[ScopedValue], optional): Args of ast.Call.
            kwargs (Dict[str, ScopedValue], optional): Kwargs of ast.Call.

        Returns:
            An instance of ast.Call.

        Raises:
            TypeError: If expr is not an instance of ScopedValue.
            RuntimeError: If value_type of 'expr' is ValueType.CustomObjValue.
            RuntimeError: If value_type of 'expr' is not ValueType.NamingValue.
        """
        if not isinstance(expr, ScopedValue):
            raise TypeError(f"expr should be ScopedValue, got: {type(expr)}")
        if expr.type == ValueType.CustomObjValue:
            raise RuntimeError("Please handle custom-object first")
        if expr.type != ValueType.NamingValue:
            raise RuntimeError(f"Expr must not be a constant, because constant can not been called: {expr.type}")
        result = ast.Call(func=AstModifier._create_name_or_attribute(expr),
                          args=AstModifier._create_call_args(args),
                          keywords=AstModifier._create_call_kwargs(kwargs))
        ast.fix_missing_locations(result)
        return result

    @staticmethod
    def get_ast_by_value(scoped_value: ScopedValue, orig_ast_node: Optional[ast.AST]) -> ast.AST:
        """
        Get ast_node by scoped_value.

        'orig_ast_node' is updated and reused in place when its node type fits the value; otherwise a new node
        is created.

        Args:
            scoped_value (ScopedValue): A value with type of ScopedValue.
            orig_ast_node (ast.AST, optional): Origin ast node to be used by ScopedValue.

        Returns:
            The updated or newly created ast node.

        Raises:
            TypeError: Input value is not a ScopedValue, or its value type is not supported.
        """
        if not isinstance(scoped_value, ScopedValue):
            raise TypeError(f"scoped_value should be ScopedValue, got: {type(scoped_value)}")
        # ast_node will not be changed when scoped_value is the unsupported type
        if scoped_value.type == ValueType.UnsupportedValue:
            return orig_ast_node if orig_ast_node else ast.Name(id=scoped_value.value, ctx=ast.Load())
        if scoped_value.type == ValueType.ConstantValue:
            new_ast_node = AstModifier.get_ast_by_constant(scoped_value, orig_ast_node)
        elif scoped_value.type == ValueType.NamingValue:
            new_ast_node = AstModifier.get_ast_by_name(scoped_value, orig_ast_node)
        elif scoped_value.type in (ValueType.ListValue, ValueType.TupleValue):
            new_ast_node = AstModifier.get_ast_by_list(scoped_value, orig_ast_node)
        elif scoped_value.type == ValueType.DictValue:
            new_ast_node = AstModifier.get_ast_by_dict(scoped_value, orig_ast_node)
        else:
            raise TypeError(f"Type of scoped_value should be one of (ConstantValue, NamingValue, ListValue, "
                            f"DictValue, TupleValue), but got {scoped_value.type}")
        ast.fix_missing_locations(new_ast_node)
        return new_ast_node

    @staticmethod
    def get_ast_by_constant(scoped_value: ScopedValue, orig_ast_node: Optional[ast.AST]):
        """
        Get ast_node by constant value.

        An existing constant node is updated in place. Otherwise a new ast.Constant is returned.
        """
        constant_value = scoped_value.value
        if isinstance(orig_ast_node, ast.Constant):
            orig_ast_node.value = constant_value
            return orig_ast_node
        # Before Python 3.8, ast.parse produced the legacy node classes below instead of ast.Constant. Since 3.8
        # they never appear in parsed trees; they warn on access from 3.12 and are removed in 3.14. So only
        # consult them on interpreters that can actually produce them.
        if sys.version_info < (3, 8):
            if isinstance(constant_value, (int, float)) and isinstance(orig_ast_node, ast.Num):
                orig_ast_node.n = constant_value
                return orig_ast_node
            if isinstance(constant_value, str) and isinstance(orig_ast_node, ast.Str):
                orig_ast_node.s = constant_value
                return orig_ast_node
            if isinstance(constant_value, bytes) and isinstance(orig_ast_node, ast.Bytes):
                orig_ast_node.s = constant_value
                return orig_ast_node
            if isinstance(constant_value, (bool, type(None))) and isinstance(orig_ast_node, ast.NameConstant):
                orig_ast_node.value = constant_value
                return orig_ast_node
        return ast.Constant(value=constant_value)

    @staticmethod
    def get_ast_by_name(scoped_value: ScopedValue, orig_ast_node: Optional[ast.AST]):
        """
        Get ast_node by name value.

        Without a scope the result is an ast.Name; with a scope it is ``scope.value`` as an ast.Attribute.
        The ``ctx`` of 'orig_ast_node' is kept when it has one, otherwise ast.Load is used.
        """
        ctx = orig_ast_node.ctx if hasattr(orig_ast_node, "ctx") else ast.Load()
        # scoped_value doesn't have scope
        if not scoped_value.scope:
            if isinstance(orig_ast_node, ast.Name):
                orig_ast_node.id = scoped_value.value
                return orig_ast_node
            return ast.Name(id=scoped_value.value, ctx=ctx)
        # scoped_value has scope
        if isinstance(orig_ast_node, ast.Attribute):
            if isinstance(orig_ast_node.value, ast.Name):
                orig_ast_node.value.id = scoped_value.scope
            else:
                ctx_ = orig_ast_node.value.ctx if hasattr(orig_ast_node.value, "ctx") else ast.Load()
                orig_ast_node.value = ast.Name(id=scoped_value.scope, ctx=ctx_)
            orig_ast_node.attr = scoped_value.value
            return orig_ast_node
        return ast.Attribute(value=ast.Name(scoped_value.scope, ast.Load()), attr=scoped_value.value, ctx=ctx)

    @staticmethod
    def get_ast_by_list(scoped_value: ScopedValue, orig_ast_node: Optional[ast.AST]):
        """
        Get ast_node by scoped_value with type of TupleValue or ListValue.

        Elements are converted recursively with get_ast_by_value. The original element at the same index is
        offered for reuse.
        """
        ctx = orig_ast_node.ctx if hasattr(orig_ast_node, "ctx") else ast.Load()
        if scoped_value.type == ValueType.TupleValue:
            new_ast_node = orig_ast_node if isinstance(orig_ast_node, ast.Tuple) else ast.Tuple(elts=[], ctx=ctx)
        else:
            new_ast_node = orig_ast_node if isinstance(orig_ast_node, ast.List) else ast.List(elts=[], ctx=ctx)
        elts = []
        for idx, item in enumerate(scoped_value.value):
            orig_elt_ast = new_ast_node.elts[idx] if len(new_ast_node.elts) > idx else None
            elts.append(AstModifier.get_ast_by_value(item, orig_elt_ast))
        new_ast_node.elts = elts
        return new_ast_node

    @staticmethod
    def get_ast_by_dict(scoped_value: ScopedValue, orig_ast_node: Optional[ast.AST]):
        """
        Get ast_node by scoped_value with type of DictValue.

        Keys and values are converted recursively with get_ast_by_value. The original key/value node at the
        same position is offered for reuse.
        """
        new_ast_node = orig_ast_node if isinstance(orig_ast_node, ast.Dict) else ast.Dict(keys=[], values=[])
        keys = []
        values = []
        for idx, (key, value) in enumerate(scoped_value.value.items()):
            orig_key_ast = new_ast_node.keys[idx] if len(new_ast_node.keys) > idx else None
            orig_value_ast = new_ast_node.values[idx] if len(new_ast_node.values) > idx else None
            keys.append(AstModifier.get_ast_by_value(key, orig_key_ast))
            values.append(AstModifier.get_ast_by_value(value, orig_value_ast))
        new_ast_node.keys = keys
        new_ast_node.values = values
        return new_ast_node