1# Copyright 2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""SymbolTree class define of Rewrite according to forward function of a network.""" 16import stat 17from typing import Optional, Union, Tuple, Any, Dict, List 18import types 19import os 20import sys 21import ast 22import importlib.util 23import time 24import inspect 25from textwrap import dedent 26from collections import OrderedDict 27 28from mindspore.nn import Cell 29from mindspore import log as logger 30from .symbol_tree_dumper import SymbolTreeDumper 31from ..node import Node, TreeNode, ControlFlow, CallFunction, NodeManager 32from ..api.node_type import NodeType 33from ..api.scoped_value import ScopedValue, ValueType 34from ..ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder, \ 35 AstImportFinder 36from ..common.namer import TargetNamer, NodeNamer, ClassNamer 37from ..common.observer import Observer 38from ..common.observable import Observable 39from ..common.event import Event 40 41if sys.version_info >= (3, 9): 42 import ast as astunparse # pylint: disable=reimported, ungrouped-imports 43else: 44 import astunparse 45 46class Position: 47 """ 48 Position indicates a source code position in one network. 49 50 Rewrite recommend using class method `create()` of position rather than constructor of Position. 51 52 Args: 53 symbol_tree (SymbolTree): A handler of SymbolTree indicated position in which SymbolTree. 54 node (Node): A handler of Node indicated position is around which Node. 55 before_node (bool): A bool indicated position is before or after the 'node'. 56 """ 57 58 def __init__(self, symbol_tree, node, before_node: bool): 59 self.symbol_tree = symbol_tree 60 self.node = node 61 self.before_node = before_node 62 63 @classmethod 64 def create(cls, symbol_tree, node, before_node): 65 """ 66 Class method of Position. Return None when symbol_tree or node is None. 67 68 Args: 69 symbol_tree: A handler of SymbolTree indicated position in which SymbolTree. 70 node: A handler of Node indicated position is around which Node. 71 before_node (bool): A bool indicated position is before or after the 'node'. 72 73 Returns: 74 A Position. 75 """ 76 if symbol_tree is None or node is None: 77 return None 78 return Position(symbol_tree, node, before_node) 79 80 81class FieldFinder(AstFinder): 82 """ 83 Check whether field exist in specific scope. 84 85 Args: 86 scope (ast.AST): An instance of ast node as search scope. 87 """ 88 89 def __init__(self, scope: ast.AST): 90 super().__init__(scope) 91 self._result = False 92 self._field_name = "" 93 94 def visit_Attribute(self, node: ast.Attribute) -> Any: 95 """Visit a node of type ast.Attribute.""" 96 value = node.value 97 if not isinstance(value, ast.Name): 98 return super(FieldFinder, self).generic_visit(node) 99 if value.id != "self": 100 return super(FieldFinder, self).generic_visit(node) 101 if node.attr == self._field_name: 102 self._result = True 103 return super(FieldFinder, self).generic_visit(node) 104 105 def check(self, field) -> bool: 106 """ 107 Check whether `field` exist in scope. 108 109 Args: 110 field (str): A string indicates target field name. 111 112 Returns: 113 A bool indicate whether `field` exist in scope. 114 """ 115 self._result = False 116 self._field_name = field 117 self.visit(self._scope) 118 return self._result 119 120 121class SymbolTree(Observer, Observable, NodeManager): 122 """ 123 A symbol-tree usually corresponding to forward method of a network. 124 125 Rewrite recommend using SymbolTreeBuilder to instantiate an instance of SymbolTree rather than invoking constructor 126 of SymbolTree directly. 127 128 Args: 129 origin_network (Cell): A handler to original network instance. 130 module_ast (ast.Module): An instance of ast.AST represents ast node of original network. 131 """ 132 # whether parse CallFunction node inserted by user. 133 _unparse_inserted_function = True 134 135 def __init__(self, origin_network: Cell, module_ast: ast.Module): 136 Observer.__init__(self) 137 Observable.__init__(self) 138 self._node_namer = NodeNamer() 139 self._node_namer.add_name('obj') 140 NodeManager.__init__(self) 141 NodeManager.set_manager_node_namer(self, self._node_namer) 142 NodeManager.reg_observer(self, observer=self) 143 # init unique-namers 144 self._target_namer = TargetNamer() 145 # input arguments of function 146 self._ori_cls_name = type(origin_network).__name__ 147 self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name) 148 NodeManager.set_manager_name(self, self._opt_cls_name) 149 self._origin_network = origin_network 150 self._module_ast: ast.Module = module_ast 151 self._import_asts: Optional[ast.Ast] = [] 152 self._class_ast: Optional[ast.ClassDef] = None 153 self._root_ast: Optional[ast.FunctionDef] = None 154 self._init_func_ast: Optional[ast.FunctionDef] = None 155 self._deleted_field = {} 156 self._deleted_node = [] 157 # {ast_function: [import_asts]} 158 self._external_ast: Dict[ast.FunctionDef, list] = OrderedDict() 159 # {ast_class: [import_asts]} 160 self._father_class_ast: Dict[ast.ClassDef, list] = OrderedDict() 161 self._modified = False 162 self._saved_file_name = "./network_define.py" 163 # used to insert "sys.path.append(xxx)" 164 self._net_file_paths = [] 165 self._tmp_import_strs = [] 166 self._tmp_unmodified_strees: {type, List[SymbolTree]} = {} 167 self._tmp_replacers = [] 168 # user custom codes 169 self._custom_codes: List[ast.AST] = [] 170 # local primitive instances initialized during forward method, e.g. abs_inst = P.Abs() 171 self._local_prim_inits: List[Node] = [] 172 173 @staticmethod 174 def _remove_unused_import(module_ast): 175 """remove unused import in self._module_ast""" 176 import_nodes: List[Union[ast.Import, ast.ImportFrom]] = [] 177 178 def is_divider(ast_node): 179 """judge if ast node is divider of new class or function by checking ast.Expr of '#'.""" 180 return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#' 181 182 for ast_node in module_ast.body[:]: 183 if isinstance(ast_node, (ast.Import, ast.ImportFrom)): 184 import_nodes.append(ast_node) 185 if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)): 186 str_checker = StrChecker(ast_node) 187 for import_node in import_nodes: 188 for alias in import_node.names[:]: 189 name = alias.asname if alias.asname else alias.name 190 if name == '*': 191 continue 192 if not str_checker.check(name): 193 import_node.names.remove(alias) 194 if not import_node.names: 195 module_ast.body.remove(import_node) 196 if is_divider(ast_node): 197 import_nodes.clear() 198 199 @staticmethod 200 def _remove_duplicated_import(module_ast): 201 """Remove duplicated import of 'net'.""" 202 imports = set() 203 futures = set() 204 names = set() 205 206 class TransImportNode(ast.NodeTransformer): 207 """Find all import nodes from input ast node.""" 208 209 def visit_ClassDef(self, node: ast.ClassDef) -> Any: 210 if node.name not in names: 211 names.add(node.name) 212 return node 213 return None 214 215 def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: 216 if node.name not in names: 217 names.add(node.name) 218 return node 219 return None 220 221 def visit_Try(self, node: ast.Try) -> Any: 222 if isinstance(node.body[0], (ast.Import, ast.ImportFrom)): 223 import_str = astunparse.unparse(node) 224 if import_str not in imports: 225 imports.add(import_str) 226 return node 227 return None 228 229 def visit_Import(self, node: ast.Import) -> Any: 230 import_str = astunparse.unparse(node) 231 if import_str not in imports: 232 imports.add(import_str) 233 return node 234 return None 235 236 def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: 237 """ 238 Once the father class 'A' is defined in the current module, all the next imported class 'A' should 239 be removed. e.g. 240 def class A(): 241 ... 242 from xxx import A, B 243 => 244 def class A(): 245 ... 246 from xxx import B 247 """ 248 import_str = astunparse.unparse(node) 249 250 if import_str not in imports: 251 imports.add(import_str) 252 # remove "__future__" module 253 if node.module == '__future__': 254 futures.add(node.module) 255 return None 256 # remove modules which have been defined in the code file 257 # it occurs when class A is a father class and other sub-classes import A 258 for alias in node.names[:]: 259 if alias.name in names: 260 node.names.remove(alias) 261 # if the alias(es) in node.names are all removed, this import statement should be removed 262 if not node.names: 263 return None 264 return node 265 return None 266 267 get_node_handler = TransImportNode() 268 get_node_handler.generic_visit(module_ast) 269 270 @staticmethod 271 def _remove_arg_annotations(module_ast): 272 """Remove annotations in ast.arg to avoid 'xxx is not defined'.""" 273 ast_args: List[ast.arg] = AstFinder(module_ast).find_all(ast.arg) 274 for ast_arg in ast_args: 275 ast_arg.annotation = None 276 277 @staticmethod 278 def _check_import(import_path: str, import_module: str): 279 """ 280 Check whether import operation is valid when importing module from specific path. 281 """ 282 if import_path not in sys.path: 283 sys.path.append(import_path) 284 try: 285 importlib.import_module(name=import_module) 286 except (ValueError, ImportError) as e: 287 logger.info(f"Test import {import_module} from {import_path} failed: {e}.") 288 return False 289 except Exception as e: # pylint: disable=W0703 290 logger.info(f"Test import {import_module} from {import_path} failed: {e}.") 291 return False 292 return True 293 294 @staticmethod 295 def _process_relative_import(import_node: Union[ast.Import, ast.ImportFrom], file_path: str): 296 """Process relative imports""" 297 file_path = os.path.normcase(file_path) 298 file_path = os.path.normpath(file_path) 299 if isinstance(import_node, ast.ImportFrom): 300 # pad the ImportFrom with parent path 301 # e.g. from ..C import xxx -> from A.B.C import xxx 302 import_module = SymbolTree._get_valid_import_info(import_node, file_path) 303 if import_module: 304 import_node = ast.ImportFrom(module=import_module, names=import_node.names, level=0) 305 return import_node 306 307 @staticmethod 308 def _get_valid_import_info(import_node: ast.ImportFrom, file_path: str): 309 """Get valid import info while import_node.module is at form of relative path""" 310 file_path = os.path.dirname(os.path.abspath(file_path)) 311 # get real path from import_node.level 312 # from .(A) import xxx: current path 313 # from ..(A) import xxx: last level path 314 level = import_node.level 315 # from A import xxx: it does not need to pad, directly return the module name 316 if level == 0: 317 return import_node.module 318 if level > 1: 319 for _ in range(level - 1): 320 file_path = os.path.dirname(file_path) 321 file_path_tmp = file_path[:] 322 max_level_count = file_path.count(os.path.sep) - 1 323 level_count = 0 324 # suffix is the module_name, e.g. 'A' in 'from ..(A) import xxx' 325 suffix = '' 326 if import_node.module: 327 suffix = '.' + import_node.module 328 while level_count < max_level_count: 329 file_path_tmp = os.path.dirname(file_path_tmp) 330 if file_path_tmp not in sys.path: 331 logger.debug(f"{file_path_tmp} not in sys.path, try upper level.") 332 level_count += 1 333 continue 334 import_module = file_path[len(file_path_tmp) + 1:].replace(os.path.sep, '.') + suffix 335 if SymbolTree._check_import(file_path_tmp, import_module): 336 # try test code success 337 return import_module 338 # test import ast failed, try upper level 339 level_count += 1 340 logger.info(f"Try upper level.") 341 # try codes with all level failed 342 logger.info(f"Test import code: {astunparse.unparse(import_node).strip()} failed, ignore this import code.") 343 return None 344 345 @staticmethod 346 def insert_to_ast_while_insert_input(new_node: Node, node_manager: NodeManager): 347 """update ast when inserting NodeType.Input node""" 348 if not isinstance(node_manager, (SymbolTree, CallFunction)): 349 raise ValueError(f"Only support insert Input node into a SymbolTree or a node with type of " 350 f"CallFunction, but get {type(node_manager)}") 351 # insert a new input 352 node_manager.get_input_nodes().append(new_node) 353 ast_function: ast.FunctionDef = node_manager.get_manager_ast() 354 arg: str = new_node.get_targets()[0].value 355 ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None) 356 AstModifier.append_arg_to_function(ast_function, ast_arg) 357 358 @staticmethod 359 def insert_to_ast_while_insert_cell_primitive(new_node: Node, base_node: Node, before_node: bool, 360 node_manager: NodeManager, stree): 361 """update ast when inserting NodeType.CallCell or NodeType.CallPrimitive node""" 362 # create a new assign statement 363 ast_assign = new_node.get_ast() 364 if ast_assign is None: 365 func_name = stree.unique_func_name(new_node.get_name()) 366 new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self")) 367 ast_assign = new_node.update_ast_node() 368 if not isinstance(ast_assign, ast.Assign): 369 raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}") 370 # Save instance into _origin_network. 371 setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance()) 372 # Insert ast to __init__ function 373 if isinstance(new_node, TreeNode): 374 init_code = f"{new_node.get_func_name()} = " \ 375 f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})" 376 else: 377 init_code = f"{new_node.get_func_name()} = obj.{new_node.get_name()}" 378 init_ast = ast.parse(init_code).body[0] 379 AstModifier.insert_ast_to_function(stree.get_init_func_ast(), init_ast) 380 # Insert ast to construct_function/class_internal_function 381 ast_base_node = base_node.get_ast() if base_node else None 382 ast_node_manager = node_manager.get_manager_ast() 383 if not ast_node_manager: 384 raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} " 385 "when inserting the ast.") 386 AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node) 387 388 @staticmethod 389 def insert_to_ast_while_insert_function(new_node: CallFunction, base_node: Node, before_node: bool, 390 node_manager: NodeManager, stree: 'SymbolTree'): 391 """update ast when inserting NodeType.CallFunction node""" 392 func_name = str(new_node.get_func_name()) 393 # create a new assign statement 394 ast_assign = new_node.get_ast() 395 if ast_assign is None: 396 ast_assign = new_node.update_ast_node() 397 # Insert ast to node_manager 398 ast_base_node = base_node.get_ast() if base_node else None 399 ast_node_manager = node_manager.get_manager_ast() 400 if not ast_node_manager: 401 raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} " 402 "when inserting the ast.") 403 AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node) 404 # Ignore Python builtin functions 405 func_obj = new_node.get_instance() 406 if isinstance(func_obj, types.BuiltinFunctionType): 407 logger.warning(f"Ignore built in function: {func_name}") 408 return 409 # get ast.FunctionDef 410 source_code = inspect.getsource(func_obj) 411 ast_functiondef = ast.parse(dedent(source_code)).body[0] 412 if SymbolTree._unparse_inserted_function or not isinstance(ast_functiondef, ast.FunctionDef): 413 logger.debug(f"import '{func_name}' to access function object") 414 # add import to make sure that the function object can be accessed. 415 module = inspect.getmodule(func_obj) 416 top_node_manager = node_manager.get_top_manager() 417 belonging_ast = None if isinstance(top_node_manager, SymbolTree) else top_node_manager.get_manager_ast() 418 stree.add_import(module, func_name, belonging_ast) 419 return 420 # parse nodes in inserted function. 421 new_node.set_manager_ast(ast_functiondef) 422 new_node.set_manager_node_namer(stree.get_node_namer()) 423 stree.get_external_ast()[ast_functiondef] = [] 424 # import module which function defined in 425 func_file_path = inspect.getabsfile(func_obj) 426 stree.save_imports_from_file(func_file_path, ast_functiondef) 427 # expand ast codes in function 428 from ..ast_helpers import AstFlattener 429 ast_functiondef = AstFlattener().transform(ast_functiondef, [func_name], stree) 430 # parse ast codes into CallFunction Node 431 from ..parsers import ParserRegister 432 parser = ParserRegister.instance().get_parser(ast.FunctionDef) 433 parser.process(stree, ast_functiondef, node_manager=new_node) 434 435 @staticmethod 436 def insert_to_ast_while_insert_node(new_node: Node, base_node: Node, before_node: bool): 437 """ insert_to_ast_while_insert_node. """ 438 stree = new_node.get_belong_symbol_tree() 439 if not stree: 440 raise ValueError(f"When inserting node to ast, the belonging symbol tree of new_node is None.") 441 node_manager = new_node.get_node_manager() 442 if not isinstance(node_manager, (SymbolTree, CallFunction, ControlFlow)): 443 raise ValueError(f"When inserting node to ast, the node_manager of new_node {new_node.get_name()} can " 444 f"only be one of [SymbolTree, CallFunction, ControlFlow], but get {type(node_manager)}") 445 if new_node.get_node_type() == NodeType.Input: 446 SymbolTree.insert_to_ast_while_insert_input(new_node, node_manager) 447 elif new_node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree): 448 SymbolTree.insert_to_ast_while_insert_cell_primitive(new_node, base_node, before_node, node_manager, 449 stree) 450 elif new_node.get_node_type() == NodeType.CallFunction: 451 SymbolTree.insert_to_ast_while_insert_function(new_node, base_node, before_node, node_manager, stree) 452 else: 453 raise ValueError(f"When insert node '{new_node.get_name()}' into ast, the type of node can only be " 454 f"one of [Input, CallCell, CallPrimitive, CallFunction, Tree], but got " 455 f"{new_node.get_node_type()}.") 456 457 @staticmethod 458 def get_node_full_name(node: Node) -> str: 459 """Get full name of node""" 460 name = node.get_manager_name() if isinstance(node, NodeManager) else node.get_name() 461 # traverse node_manager with type of Node 462 node_manager = node.get_node_manager() 463 while isinstance(node_manager, Node): 464 name = f"{node_manager.get_manager_name()}.{name}" 465 node_manager = node_manager.get_node_manager() 466 # type of node_manager is SymbolTree now 467 name = f"{node_manager.get_manager_name()}.{name}" 468 return name 469 470 def local_prim_inits(self) -> List[Node]: 471 """get local primitives constructed during forward method""" 472 return self._local_prim_inits 473 474 def finish_build(self): 475 """Add Event.TopologicalChangeEvent event when build is finished.""" 476 self.add_event(Event.TopologicalChangeEvent) 477 478 def get_ori_cls_name(self) -> str: 479 """ 480 Get class name of original network. 481 482 Returns: 483 A str represents class name of original network. 484 """ 485 return self._ori_cls_name 486 487 def get_opt_cls_name(self) -> str: 488 """ 489 Get class name of rewritten network. 490 491 Returns: 492 A str represents class name of rewritten network. 493 """ 494 return self._opt_cls_name 495 496 def get_module_ast(self): 497 """ 498 Getter of `_module_ast`. 499 500 Returns: 501 An instance of ast.AST represents ast node of corresponding module. 502 """ 503 return self._module_ast 504 505 def set_module_ast(self, ast_node: ast.Module): 506 """ 507 Setter of _module_ast. 508 509 Args: 510 ast_node (ast.Module): An instance of ast.Module represents ast node of module of corresponding network 511 class. 512 """ 513 self._module_ast = ast_node 514 515 def get_ast_root(self): 516 """ 517 Getter of `_root_ast`. 518 519 Returns: 520 An instance of ast.AST represents ast node of corresponding forward method. 521 """ 522 return self._root_ast 523 524 def set_ast_root(self, ast_node: ast.FunctionDef): 525 """ 526 Setter of _root_ast. 527 528 Args: 529 ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of forward method of 530 corresponding network class. 531 """ 532 self._root_ast = ast_node 533 NodeManager.set_manager_ast(self, ast_node) 534 535 def get_class_ast(self): 536 """ 537 Getter of `_class_ast`. 538 539 Returns: 540 An instance of ast.ClassDef represents ast node of corresponding network class. 541 """ 542 return self._class_ast 543 544 def set_class_ast(self, ast_node: ast.ClassDef): 545 """ 546 Setter of `_class_ast`. 547 548 Args: 549 ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class. 550 """ 551 self._class_ast = ast_node 552 553 def get_init_func_ast(self): 554 """ 555 Getter of _init_func_ast. 556 557 Returns: 558 An instance of ast.FunctionDef represents ast node of init method of corresponding network class. 559 """ 560 return self._init_func_ast 561 562 def set_init_func_ast(self, ast_node: ast.FunctionDef): 563 """ 564 Setter of _init_func_ast. 565 566 Args: 567 ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of init method of 568 corresponding network class. 569 """ 570 self._init_func_ast = ast_node 571 572 def get_origin_network(self): 573 """ 574 Getter of `_origin_network`. 575 576 Returns: 577 An instance of Cell which represents original network. 578 """ 579 return self._origin_network 580 581 def get_nodes_dict(self): 582 """Get dict of nodes""" 583 return self._nodes 584 585 def get_node_namer(self): 586 """Get _node_namer""" 587 return self._node_namer 588 589 def is_modified(self): 590 """ 591 Check whether symbol tree is modified. 592 593 Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after 594 the symbol tree is created. 595 """ 596 return self._modified 597 598 def set_modified_true(self): 599 """ 600 Set self._modified true. 601 602 Self._modified is set true when 'if' exists in the original network. 603 In this situation, different original network instance tends to be different. 604 Hence, the class name should be updated. 605 """ 606 self._modified = True 607 608 def get_import_asts(self): 609 """Get _import_asts""" 610 return self._import_asts 611 612 def get_external_ast(self): 613 """Get _external_ast""" 614 return self._external_ast 615 616 def get_father_class_ast(self): 617 """Get _father_class_ast""" 618 return self._father_class_ast 619 620 def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]: 621 """ 622 Getter of inputs in topological relation of current 'node_or_name'. 623 624 Args: 625 node_or_name (Union[Node, str]): An instance of node or a str represents name of node. 626 627 Returns: 628 A list of instances of Node as input nodes if 'node_or_name' belong to current SymbolTree. An empty list if 629 'node_or_name' not belong to current SymbolTree. 630 """ 631 632 real_node: Optional[Node] = self._get_real_node(node_or_name) 633 if real_node is None: 634 logger.info("Node(%s) is not belong to current SymbolTree", node_or_name) 635 return [] 636 return node_or_name.get_inputs() 637 638 def get_node_users(self, node_or_name: Union[Node, str]) -> [Tuple[Node, int]]: 639 """ 640 Getter of outputs in topological relation of current 'node_or_name'. 641 642 Args: 643 node_or_name (Union[Node, str]): An instance of node or a str represents name of node. 644 645 Returns: 646 A list of instances of Node as output nodes if 'node_or_name' belong to current SymbolTree. An empty list if 647 'node_or_name' not belong to current SymbolTree. 648 """ 649 650 real_node: Optional[Node] = self._get_real_node(node_or_name) 651 if real_node is None: 652 logger.info("Node(%s) is not belong to current SymbolTree", node_or_name) 653 return [] 654 if real_node.get_node_type() == NodeType.Output: 655 return [] 656 node_users = [] 657 for target_users in real_node.get_target_users().values(): 658 if not target_users: 659 continue 660 if target_users not in node_users: 661 node_users.extend(target_users) 662 return node_users 663 664 def before(self, node_or_name: Union[Node, str]) -> Position: 665 """ 666 Get insert position before 'node_or_name' in source code list. 667 Consider using symbol_tree, node and before/after as position for sub-tree feature. 668 669 Note: 670 Topological order is not determined here which is determined by arguments of node and updated by 671 TopologicalManager automatically. 672 673 Args: 674 node_or_name (Union[Node, str]): An instance of node or a str represents name of node. 675 676 Returns: 677 A Position represents an insert point. 678 679 Raises: 680 AssertError: If 'node_or_name' is not a Node or a str 681 RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current 682 SymbolTree. 683 """ 684 685 node = self._get_real_node(node_or_name) 686 if node is None: 687 raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) 688 return Position.create(node.get_belong_symbol_tree(), node, True) 689 690 def after(self, node_or_name: Union[Node, str]) -> Position: 691 """ 692 Get insert position after 'node_or_name' in source code list. 693 Consider using symbol_tree, node and before/after as position for sub-tree feature. 694 695 Note: 696 Topological order is not determined here which is determined by arguments of node and updated by 697 TopologicalManager automatically. 698 699 Args: 700 node_or_name (Union[Node, str]): An instance of node or a str represents name of node. 701 702 Returns: 703 A Position represents an insert point. 704 705 Raises: 706 AssertError: If 'node_or_name' is not a Node or a str 707 RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current 708 SymbolTree. 709 """ 710 node = self._get_real_node(node_or_name) 711 if node is None: 712 raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) 713 return Position.create(node.get_belong_symbol_tree(), node, False) 714 715 def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None, 716 insert_to_ast: bool = True): 717 """ 718 Insert a node before or after base_node. 719 720 Note: 721 Name of node will be unique while inserting node into SymbolTree. 722 723 ValueType.CustomObjValue type arguments will be converted to ValueType.NamingValue and custom object will 724 be saved in global_vars dict while inserting node into SymbolTree. 725 726 Targets of node will be unique while inserting node into SymbolTree. 727 728 A field instantiation statement will be added into "init" function of network class using node name as field 729 name when `insert_to_ast` is True while inserting node into SymbolTree. 730 731 An assign statement represents invoking to this node will be added into forward function of network class 732 corresponding to field-instantiation-statement when `insert_to_ast` is True while inserting node into 733 SymbolTree. 734 735 Topological relation is updated and inputs of corresponding node is updated. 736 737 Args: 738 new_node (Node): Node to be inserted. 739 base_node (Node): New node will be inserted before or after base_node. 740 before_node (bool): Indicate whether new node is inserted before base_node. 741 node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to 742 NodeManager of symboltree's construct function. 743 insert_to_ast (bool): Indicate whether ast nodes need to be updated. 744 745 Returns: 746 An instance of node which has been inserted into SymbolTree. 747 748 Raises: 749 ValueError: Node in the SymbolTree is inserted into SymbolTree again. 750 RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True. 751 """ 752 if new_node.get_belong_symbol_tree(): 753 raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}") 754 755 # Check if base_node in current SymbolTree 756 if base_node is not None: 757 stree = base_node.get_belong_symbol_tree() 758 if stree is not None and stree is not self: 759 raise ValueError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, " 760 f"current: {self.get_ori_cls_name()}.") 761 762 # Check if node is inserted between Input node 763 if base_node is not None and base_node.get_node_type() == NodeType.Input: 764 valid = True 765 if before_node: 766 valid = False 767 if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input: 768 valid = False 769 if not valid: 770 raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name()) 771 772 # save target name, which is used to provide unique target 773 if new_node.get_targets(): 774 for target in new_node.get_targets(): 775 self._target_namer.add_name(str(target)) 776 777 self._handle_custom_obj_in_normalized_args(new_node) 778 779 # Insert node into NodeManager 780 if node_manager is None: 781 if base_node is None: 782 raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.") 783 node_manager = base_node.get_node_manager() 784 785 # set node's _belong_symbol_tree 786 new_node.set_belong_symbol_tree(self) 787 788 if node_manager is self: 789 NodeManager.insert_node(self, new_node, base_node, before_node) 790 if insert_to_ast: 791 # update init-function-ast and construct-function-ast 792 self.insert_to_ast_while_insert_node(new_node, base_node, before_node) 793 else: 794 node_manager.insert_node(new_node, base_node, before_node, insert_to_ast) 795 796 # register code changed event observer, which is used to update _modified flag. 797 if new_node.get_node_type() == NodeType.Tree: 798 new_node.symbol_tree.reg_observer(self) 799 elif isinstance(new_node, NodeManager): 800 new_node.reg_observer(self) 801 802 return new_node 803 804 def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node: 805 """ 806 Append a node to SymbolTree. 807 808 Args: 809 node (Node): An instance of node to be appended. 810 append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is 811 True. 812 node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to 813 NodeManager of symboltree's construct function. 814 815 Returns: 816 An instance of node which has been appended to SymbolTree. 817 """ 818 if node_manager is None: 819 node_manager = self 820 return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast) 821 822 def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node: 823 """ 824 Append an original field node to SymbolTree. An original field node represents a node created from existing 825 statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update 826 while these nodes appending to SymbolTree. 827 This method is called while building SymbolTree usually. 828 829 Args: 830 node (Node): An instance of node to be appended. 831 node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to 832 NodeManager of symboltree's construct function. 833 834 Returns: 835 An instance of node which has been appended to SymbolTree. 836 """ 837 return self.append_node(node, node_manager, False) 838 839 def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None, 840 node_manager: NodeManager = None): 841 """ 842 Append an input node to SymbolTree corresponding to parameter of forward method of network class. 843 This method is called while building SymbolTree usually. 844 845 Args: 846 ast_node (ast.AST): A ast Node corresponding to current parameter. 847 param_name (str): A str represents name of parameter of forward method of network class. 848 default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which 849 means parameter has no default value. 850 node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to 851 NodeManager of symboltree's construct function. 852 853 Returns: 854 An instance of input node which has been appended to SymbolTree. 855 """ 856 if param_name == "self": 857 return 858 # check param_name duplicated 859 if node_manager is None: 860 node_manager = self 861 for input_node in node_manager.get_input_nodes(): 862 targets = input_node.get_targets() 863 if len(targets) != 1: 864 raise RuntimeError("targets should have 1 elements") 865 target: ScopedValue = targets[0] 866 if target.type != ValueType.NamingValue: 867 raise RuntimeError("target.type should equal to ValueType.NamingValue") 868 if target.scope != "": 869 raise RuntimeError("target.scope should be empty") 870 exist_param = target.value 871 if exist_param == param_name: 872 raise RuntimeError("input duplicated:", param_name) 873 input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}") 874 self.append_origin_field(input_node, node_manager) 875 876 def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, 877 node_manager: NodeManager = None) -> Optional[Node]: 878 """ 879 Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is 880 a list or a dict. 881 This method is called while building SymbolTree usually. 882 883 Args: 884 ast_scope (ast.AST): A ast node represents ast node of scope of node. 885 ast_node (ast.AST): A ast node represents ast node. 886 node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to 887 NodeManager of symboltree's construct function. 888 889 Returns: 890 An instance of python node if a new node has been appended to SymbolTree else None. 891 """ 892 if ast_node is None: 893 return None 894 if isinstance(ast_node, (list, dict)) and not ast_node: 895 return None 896 return self.append_python_node(ast_scope, ast_node, node_manager) 897 898 def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node: 899 """ 900 Append a python node to SymbolTree. 901 This method is called while building SymbolTree usually. 902 903 Args: 904 ast_scope (ast.AST): A ast node represents ast node of scope of node. 905 ast_node (ast.AST): A ast node represents ast node. 906 node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to 907 NodeManager of symboltree's construct function. 908 909 Returns: 910 An instance of python node which has been appended to SymbolTree. 911 """ 912 logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__) 913 node_name = type(ast_node).__name__ 914 node = Node.create_python_node(ast_node, node_name) 915 if node_manager is None or node_manager is self: 916 NodeManager.append_python_node(self, node) 917 else: 918 node_manager.append_python_node(node) 919 return node 920 921 def set_output(self, return_value: str, arg_index: int, return_idx: int = 0, 922 node_manager: NodeManager = None) -> Node: 923 """ 924 Update return value of return of forward method of network class. 925 926 Args: 927 return_value (str): A str represents new return value. 928 arg_index (int): A int indicates which value in return to be updated. 929 return_idx (int): A int indicates which return node to be updated. Default: 0. 930 node_manager (NodeManager): NodeManager those asts belong to. Default: None, means 931 symboltree's construct function. 932 933 Returns: 934 An instance of node represents return node after updated. 935 """ 936 node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns() 937 if not node_returns: 938 raise RuntimeError("Current node_manager has no output") 939 if return_idx >= len(node_returns): 940 raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.") 941 node_return = node_returns[return_idx] 942 self.set_node_arg(node_return, arg_index, return_value) 943 return node_return 944 945 def erase_node(self, node_or_name: Union[Node, str]) -> Node: 946 """ 947 Erase a node from SymbolTree. 948 949 Topological relation will be updated. 950 951 Args: 952 node_or_name (Union[Node, str]): An instance of node or a str represents name of node. 953 954 Returns: 955 An instance of node which has been erased from SymbolTree. 956 957 Raises: 958 RuntimeError: If 'node_or_name' is not in current SymbolTree. 959 RuntimeError: If erase corresponding ast node failed. 960 """ 961 962 node = self._get_real_node(node_or_name) 963 if node is None: 964 raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) 965 # erase node in NodeManager 966 node_manager = node.get_node_manager() 967 968 logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, " 969 f"node_manager: {node_manager.get_manager_name()}, " 970 f"code: {astunparse.unparse(node.get_ast()).strip()}, " 971 f"node_name:{node.get_name()}") 972 973 if node_manager is self: 974 NodeManager.erase_node(self, node) 975 if isinstance(node, ControlFlow): 976 ret = AstModifier.earse_ast_of_control_flow(self._root_ast.body, node.get_ast(), node.is_orelse) 977 else: 978 ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast()) 979 if not ret: 980 raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.") 981 else: 982 node_manager.erase_node(node) 983 node.set_belong_symbol_tree(None) 984 self._deleted_node.append(node.get_name()) 985 return node 986 987 def replace(self, old_node: Node, new_nodes: [Node]) -> Node: 988 """ 989 Replace an old_node with a node list. 990 991 Args: 992 old_node (Node): Node to be replaced. 993 new_nodes (list[Node]): Node list to replace in. 994 995 Returns: 996 Last node in new_nodes list. 997 998 Raises: 999 RuntimeError: If 'old_node' is isolated. 1000 RuntimeError: If 'old_node' is not belong to current SymbolTree. 1001 """ 1002 real_old_node = self._get_real_node(old_node) 1003 if real_old_node is None: 1004 raise RuntimeError("Old node is not belong to current SymbolTree:", old_node) 1005 # insert new_nodes into node_manager 1006 node_manager = real_old_node.get_node_manager() 1007 # insert new_nodes into NodeManager 1008 base_node = old_node 1009 for node in new_nodes: 1010 self.insert_node(node, base_node, False, node_manager, True) 1011 base_node = node 1012 self.erase_node(old_node) 1013 return new_nodes[-1] 1014 1015 def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]): 1016 """ 1017 Set argument of 'node'. 1018 1019 Args: 1020 node (Union[Node, str]): Node to be modified. Can be a node or name of node. 1021 index (int): Indicate which input being modified. 1022 arg (Union[ScopedValue, str]): New argument to been set. 1023 1024 Raises: 1025 RuntimeError: If 'node' is not belong to current SymbolTree. 1026 """ 1027 1028 real_node = self._get_real_node(node) 1029 if real_node is None: 1030 raise RuntimeError("Node is not belong to current SymbolTree: ", node) 1031 1032 new_arg, old_arg = node.set_arg(arg, index) 1033 node.get_node_manager().on_update_arg(node, index, old_arg, new_arg) 1034 1035 def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str], 1036 out_idx: Optional[int] = None): 1037 """ 1038 Set argument of 'dst_node' by another Node. 1039 1040 Args: 1041 dst_node (Node): Node to be modified. Can be a node or name of node. 1042 arg_idx (int): Indicate which input being modified. 1043 src_node (Node): Node as new input. Can be a node or name of node. 1044 out_idx ([int, optional]): Indicate which output of 'src_node' as new input of 'dst_node'. Default is None 1045 which means use first output of 'node_to_link' as new input. 1046 1047 Raises: 1048 RuntimeError: If 'dst_node' is not belong to current SymbolTree. 1049 RuntimeError: If 'src_node' is not belong to current SymbolTree. 1050 RuntimeError: If 'out_idx' is out of range. 1051 RuntimeError: If 'src_node' has multi-outputs while 'out_idx' is None or 'out_idx' is not offered. 1052 """ 1053 1054 real_dst_node = self._get_real_node(dst_node) 1055 if real_dst_node is None: 1056 raise RuntimeError("dst_node is not belong to current SymbolTree: ", dst_node) 1057 real_src_node = self._get_real_node(src_node) 1058 if real_src_node is None: 1059 raise RuntimeError("src_node is not belong to current SymbolTree: ", src_node) 1060 1061 targets = real_src_node.get_targets() 1062 if out_idx is None: 1063 if len(targets) != 1: 1064 raise RuntimeError("node should has one output when out_idx is not provided") 1065 out_idx = 0 1066 if out_idx >= len(targets): 1067 raise RuntimeError("out_idx out of range: ", out_idx) 1068 new_arg = targets[out_idx] 1069 real_dst_node.set_arg(new_arg, arg_idx) 1070 real_dst_node.get_node_manager().on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx) 1071 1072 def unique_name(self, name: str): 1073 """Get a unique name in the symboltree""" 1074 return self._target_namer.get_name(name) 1075 1076 def unique_func_name(self, name: str): 1077 """Get a unique function name in the symboltree""" 1078 if not hasattr(self._origin_network, name): 1079 return name 1080 suffix = 1 1081 while hasattr(self._origin_network, f"{name}_{suffix}"): 1082 suffix += 1 1083 return f"{name}_{suffix}" 1084 1085 def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]): 1086 """ 1087 Set target of `node` . 1088 1089 Args: 1090 node (Union[Node, str]): Node to be modified. Can be a node or name of node. 1091 index (int): Indicate which target being modified. 1092 arg (Union[ScopedValue, str]): New target to been set. 1093 1094 Raises: 1095 ValueError: If `node` is not belong to current SymbolTree. 1096 ValueError: If index of `node` 's target is greater than number of targets. 1097 """ 1098 1099 real_node = self._get_real_node(node) 1100 if real_node is None: 1101 raise ValueError("Node is not belong to current SymbolTree: ", node) 1102 if isinstance(target, str): 1103 target = ScopedValue.create_naming_value(target) 1104 targets = node.get_targets() 1105 if index >= len(targets): 1106 raise ValueError(f"Index of node's target should be less than {len(targets)}, but got {index}") 1107 old_target = targets[index] 1108 targets[index] = target 1109 node.set_targets(targets) 1110 self._topo_mgr.on_update_target(node, index, old_target, target) 1111 1112 def all_nodes(self, subtree_nodes: bool = True): 1113 """ 1114 Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree. 1115 1116 Args: 1117 subtree_nodes (bool): Whether include nodes in subtree. Default: True. 1118 1119 Returns: 1120 A list of nodes. 1121 """ 1122 nodes = [] 1123 node_managers = [self] 1124 while node_managers: 1125 node_manager = node_managers.pop() 1126 nodes.extend(node_manager.nodes()) 1127 for node in node_manager.nodes(): 1128 if isinstance(node, NodeManager): 1129 node_managers.append(node) 1130 if subtree_nodes: 1131 for tree_node in self.get_tree_nodes(): 1132 stree = tree_node.symbol_tree 1133 nodes.extend(stree.all_nodes()) 1134 return nodes 1135 1136 def get_node_from_name(self, node_name: str): 1137 """ 1138 Get node from all NodeManagers in current symbol tree by `node_name`. 1139 1140 Args: 1141 node_name (str): A str represents name of node as key of query. 1142 1143 Returns: 1144 An instance of Node if found else None. 1145 """ 1146 node_managers = [self] 1147 while node_managers: 1148 node_manager = node_managers.pop() 1149 node = node_manager.get_node(node_name) 1150 if node: 1151 return node 1152 for node in node_manager.nodes(): 1153 if isinstance(node, NodeManager): 1154 node_managers.append(node) 1155 return None 1156 1157 def get_node_tabulate(self, all_nodes: bool = False) -> str: 1158 """ 1159 Get nodes information and nodes' topological relations. 1160 1161 Args: 1162 all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction 1163 nodes, CellContainer nodes and sub symbol trees. 1164 1165 Returns: 1166 String of nodes' information and topological relations. 1167 """ 1168 try: 1169 from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource 1170 except ImportError: 1171 logger.warning("print_node_tabulate relies on the library `tabulate`, " 1172 "which could not be found on this machine. Run `pip " 1173 "install tabulate` to install the library.") 1174 return "" 1175 dump_str = NodeManager.dump(self, self.get_manager_name()) 1176 if all_nodes: 1177 node_managers = [self] 1178 while node_managers: 1179 node_manager = node_managers.pop() 1180 for node in node_manager.nodes(): 1181 if isinstance(node, NodeManager): 1182 dump_str += node.dump(SymbolTree.get_node_full_name(node)) 1183 node_managers.append(node) 1184 for tree_node in self.get_tree_nodes(): 1185 stree = tree_node.symbol_tree 1186 dump_str += stree.get_node_tabulate(all_nodes) 1187 return dump_str 1188 1189 def dump(self): 1190 """Dump graph.""" 1191 dump_st = SymbolTreeDumper(self) 1192 dump_st.dump() 1193 1194 def check_body_exist(self, body, code_bodies): 1195 """Check whether body already exist in code_bodies""" 1196 # Check import ast node exist by saving import code string to self._tmp_import_strs 1197 if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)): 1198 import_str = astunparse.unparse(body) 1199 if import_str in self._tmp_import_strs: 1200 return True 1201 self._tmp_import_strs.append(import_str) 1202 return False 1203 1204 # Check ClassDef ast node exist by using AstClassFinder 1205 if isinstance(body, ast.ClassDef): 1206 if sys.version_info >= (3, 9): 1207 class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])) 1208 else: 1209 class_finder = AstClassFinder(ast.Module(body=code_bodies)) 1210 results = class_finder.find_all(body.name) 1211 return bool(results) 1212 1213 # Check FunctionDef ast node exist by using AstFunctionFinder 1214 if isinstance(body, ast.FunctionDef): 1215 if sys.version_info >= (3, 9): 1216 function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[])) 1217 else: 1218 function_finder = AstFunctionFinder(ast.Module(body=code_bodies)) 1219 results = function_finder.find_all(body.name) 1220 return bool(results) 1221 1222 return False 1223 1224 def deduplicate_unmodified_stree(self, code_bodies): 1225 """ 1226 Init function may be different even if stree is not modified manually, when subnets in stree is 1227 initialized by different arguments. 1228 In this case, we need to wait for code_bodies being fully generated, so that the name of subnets 1229 will be updated, then we can deduplicate again according to ast of init function. 1230 """ 1231 # prepare AstClassFinder and AstReplacer 1232 if sys.version_info >= (3, 9): 1233 class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])) 1234 name_replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[])) 1235 else: 1236 class_finder = AstClassFinder(ast.Module(body=code_bodies)) 1237 name_replacer = AstReplacer(ast.Module(body=code_bodies)) 1238 # deduplicate all unmodified strees in self._tmp_unmodified_strees 1239 deduplicated = False 1240 for _, unmodified_strees in self._tmp_unmodified_strees.items(): 1241 if len(unmodified_strees) <= 1: 1242 continue 1243 init_func_codes = [astunparse.unparse(stree.get_init_func_ast()) for stree in unmodified_strees] 1244 # If the index of an element is not its own, it means that it is a duplicate element 1245 to_be_erase = [] 1246 for idx, code in enumerate(init_func_codes): 1247 first_idx = init_func_codes.index(code) 1248 if first_idx != idx: 1249 first_stree_cls_name = unmodified_strees[first_idx].get_opt_cls_name() 1250 duplicated_stree_cls_name = unmodified_strees[idx].get_opt_cls_name() 1251 logger.debug(f"replace stree:{duplicated_stree_cls_name} to {first_stree_cls_name}.") 1252 # delete duplicated class from code_bodies 1253 results = class_finder.find_all(duplicated_stree_cls_name) 1254 for ast_cls in results: 1255 code_bodies.remove(ast_cls) 1256 # replace name of duplicated class in code_bodies to first_stree_cls_name 1257 name_replacer.replace_all(duplicated_stree_cls_name, first_stree_cls_name) 1258 # record deduplicated stree 1259 to_be_erase.append(idx) 1260 deduplicated = True 1261 # remove class in self._tmp_unmodified_strees 1262 for idx in reversed(to_be_erase): 1263 unmodified_strees.pop(idx) 1264 1265 # the name of subnets is updated, so we need to deduplicate again. 1266 if deduplicated: 1267 self._tmp_replacers.append(name_replacer) 1268 self.deduplicate_unmodified_stree(code_bodies) 1269 1270 def update_unmodified_stree(self, stree, code_bodies) -> bool: 1271 """ 1272 For the unmodified symbol tree, only one definition code remains in the generated code. 1273 Everywhere else calling this symbol tree will use the class in this definition code. 1274 """ 1275 # all modified ast.ClassDef will be exported to code 1276 if stree.is_modified(): 1277 logger.debug(f"stree:{stree.get_opt_cls_name()} is modified.") 1278 return False 1279 # all un-modified ast.ClassDef only keep one instance 1280 unmodified_strees = self._tmp_unmodified_strees.get(type(stree.get_origin_network())) 1281 if not unmodified_strees: 1282 self._tmp_unmodified_strees[type(stree.get_origin_network())] = [stree] 1283 logger.debug(f"stree:{stree.get_opt_cls_name()} is the first stree.") 1284 return False 1285 # Init function may be different even if stree is not modified, when subnets in stree is 1286 # initialized by different arguments. 1287 first_stree = unmodified_strees[0] 1288 first_stree_cls_name = first_stree.get_opt_cls_name() 1289 if astunparse.unparse(stree.get_init_func_ast()) != astunparse.unparse(first_stree.get_init_func_ast()): 1290 # init ast may be updated after inserting subtrees of stree, so we need to save unmodified strees 1291 # and deduplicate later 1292 self._tmp_unmodified_strees[type(stree.get_origin_network())].append(stree) 1293 logger.debug(f"init func different, stree:{stree.get_opt_cls_name()}, first_stree:{first_stree_cls_name}.") 1294 return False 1295 # Un-modified ast.ClassDef already exist in code_bodies, 1296 # replace class name to class name of first un-modified ast.ClassDef. 1297 if sys.version_info >= (3, 9): 1298 replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[])) 1299 else: 1300 replacer = AstReplacer(ast.Module(body=code_bodies)) 1301 logger.debug(f"replace stree:{stree.get_opt_cls_name()} to {first_stree_cls_name}.") 1302 replacer.replace_all(stree.get_class_ast().name, first_stree_cls_name) 1303 self._tmp_replacers.append(replacer) 1304 return True 1305 1306 def init_code_bodies(self, code_bodies: list) -> int: 1307 """Init code bodied""" 1308 # Add basic imports 1309 code_bodies.append(ast.Import([ast.alias(name='sys', asname=None)])) 1310 code_bodies.append(ast.Import([ast.alias(name='mindspore', asname=None)])) 1311 code_bodies.append(ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0)) 1312 code_bodies.append(ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)], level=0)) 1313 code_bodies.append(ast.ImportFrom(module='mindspore.ops', 1314 names=[ast.alias(name='functional', asname='F')], level=0)) 1315 code_bodies.append(ast.Expr(ast.Name("#", ast.Load()))) 1316 # Add user custom codes into code_bodies 1317 custom_codes = self.get_custom_codes() 1318 for code_ast in custom_codes: 1319 code_bodies.append(code_ast) 1320 code_bodies.append(ast.Expr(ast.Name("#", ast.Load()))) 1321 return len(code_bodies) 1322 1323 def convert_stree_to_code_bodies(self, stree: 'SymbolTree', code_bodies: list, dividing_pos=0) -> int: 1324 """ 1325 Convert nodes in stree to code_bodies 1326 - Add external function asts into code_bodies 1327 - Add father class asts into code_bodies 1328 - Add import asts of symbol tree into code_bodies 1329 - Add user custom codes into code_bodies 1330 - Add class asts of symbol tree into code_bodies 1331 - Add subtrees to code_bodies 1332 """ 1333 insert_pos = dividing_pos 1334 # Add external asts into code_bodies 1335 for ast_func, import_asts in reversed(stree.get_external_ast().items()): 1336 if self.check_body_exist(ast_func, code_bodies): 1337 continue 1338 # add imports of external_ast 1339 self._tmp_import_strs.clear() 1340 for ast_import in import_asts: 1341 if not self.check_body_exist(ast_import, code_bodies): 1342 code_bodies.insert(insert_pos, ast_import) 1343 insert_pos += 1 1344 # add external_ast 1345 code_bodies.insert(insert_pos, ast_func) 1346 insert_pos += 1 1347 # add divide 1348 code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load()))) 1349 insert_pos += 1 1350 1351 # Add father class asts into code_bodies 1352 for ast_class, import_asts in stree.get_father_class_ast().items(): 1353 if self.check_body_exist(ast_class, code_bodies): 1354 continue 1355 # add imports of father class 1356 self._tmp_import_strs.clear() 1357 for ast_import in import_asts: 1358 if not self.check_body_exist(ast_import, code_bodies): 1359 code_bodies.insert(insert_pos, ast_import) 1360 insert_pos += 1 1361 # add ast of father class 1362 code_bodies.insert(insert_pos, ast_class) 1363 insert_pos += 1 1364 # add divide 1365 code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load()))) 1366 insert_pos += 1 1367 1368 # external functions and father class are above the dividing_pos to support deduplication. 1369 dividing_pos = insert_pos 1370 1371 # Add import asts of symbol tree into code_bodies 1372 self._tmp_import_strs.clear() 1373 for body in stree.get_import_asts(): 1374 if not self.check_body_exist(body, code_bodies): 1375 code_bodies.insert(insert_pos, body) 1376 insert_pos += 1 1377 1378 # Add class asts of symbol tree into code_bodies 1379 if stree.get_module_ast(): 1380 for body in stree.get_module_ast().body: 1381 if self.check_body_exist(body, code_bodies): 1382 continue 1383 code_bodies.insert(insert_pos, body) 1384 insert_pos += 1 1385 1386 # add divide 1387 code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load()))) 1388 insert_pos += 1 1389 1390 # Add subtrees to code_bodies 1391 for node in stree.get_tree_nodes(): 1392 sub_stree = node.symbol_tree 1393 # For the unmodified class, update class name to name of first class 1394 if self.update_unmodified_stree(sub_stree, code_bodies): 1395 continue 1396 dividing_pos = self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, dividing_pos) 1397 1398 # return new dividing position 1399 return dividing_pos 1400 1401 def get_code(self) -> str: 1402 """ 1403 Get source code of modified network. 1404 1405 Returns: 1406 A str represents source code of modified network. 1407 """ 1408 self._tmp_import_strs.clear() 1409 self._tmp_unmodified_strees.clear() 1410 self._tmp_replacers.clear() 1411 code_bodies = [] 1412 begin_pos = self.init_code_bodies(code_bodies) 1413 self.convert_stree_to_code_bodies(self, code_bodies, begin_pos) 1414 self.deduplicate_unmodified_stree(code_bodies) 1415 if sys.version_info >= (3, 9): 1416 gencode_module = ast.Module(body=code_bodies, type_ignores=[]) 1417 else: 1418 gencode_module = ast.Module(body=code_bodies) 1419 SymbolTree._remove_unused_import(gencode_module) 1420 self._process_duplicate_name_modules(gencode_module) 1421 SymbolTree._remove_duplicated_import(gencode_module) 1422 SymbolTree._remove_arg_annotations(gencode_module) 1423 ast.fix_missing_locations(self._module_ast) 1424 code = astunparse.unparse(gencode_module) 1425 # Revert the class name to its original state 1426 for replacer in self._tmp_replacers: 1427 replacer.undo_all() 1428 return code 1429 1430 def get_network(self): 1431 """ 1432 Get modified network. 1433 1434 Returns: 1435 A network object. 1436 """ 1437 cls = self._get_cls_through_file() 1438 new_net = cls(self._origin_network) 1439 self._merge_origin_property(new_net) 1440 # update parameters' names to fix duplicated names bug 1441 # which occurs after inserting cell to celllist/sequentialcell 1442 new_net.update_parameters_name() 1443 return new_net 1444 1445 def set_saved_file_name(self, file_name: str): 1446 if file_name.endswith(".py"): 1447 self._saved_file_name = file_name 1448 else: 1449 self._saved_file_name = file_name + ".py" 1450 1451 def get_saved_file_name(self): 1452 return self._saved_file_name 1453 1454 def save_network_to_file(self): 1455 abs_path = os.path.abspath(self._saved_file_name) 1456 if os.path.isfile(abs_path): 1457 os.remove(abs_path) 1458 with os.fdopen(os.open(self._saved_file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f: 1459 source = self.get_code() 1460 f.write(source.encode('utf-8')) 1461 f.flush() 1462 1463 1464 def flatten_nodes(self, node, erase_another_branch: bool = False, erase_nodes_after_return: bool = False): 1465 """Flatten nodes in ControlFlow node.""" 1466 if not isinstance(node, ControlFlow): 1467 raise ValueError(f"For flatten_nodes, the type of node can only be ControlFlow, but got {type(node)}.") 1468 upper_node_manager = node.get_node_manager() 1469 if isinstance(upper_node_manager, (SymbolTree, CallFunction)): 1470 ast_bodies = upper_node_manager.get_manager_ast().body 1471 elif isinstance(upper_node_manager, ControlFlow): 1472 ast_bodies = upper_node_manager.get_manager_ast() 1473 else: 1474 raise ValueError("For flatten_nodes, the node can only be contained in [SymbolTree, CallFunction, " 1475 f"ControlFlow], but the node is in {type(upper_node_manager)}.") 1476 base_node = node.orelse_node if node.orelse_node else node.body_node 1477 for n in node.nodes()[:]: 1478 self.erase_node(n) 1479 self.insert_node(n, base_node, False, upper_node_manager, False) 1480 AstModifier.insert_ast_to_bodies(ast_bodies, n.get_ast(), base_node.get_ast(), False) 1481 base_node = n 1482 self.erase_node(node) 1483 # remove another branch 1484 if erase_another_branch: 1485 if node.is_orelse: 1486 self.erase_node(node.body_node) 1487 elif node.orelse_node is not None: 1488 self.erase_node(node.orelse_node) 1489 # remove nodes after return node 1490 if erase_nodes_after_return: 1491 has_return = False 1492 for n in upper_node_manager.nodes(): 1493 if has_return: 1494 logger.warning(f"Node {n.get_name()} which is behind the flatten return node is " 1495 f"automatically erased.") 1496 self.erase_node(n) 1497 elif n.get_node_type() == NodeType.Output: 1498 has_return = True 1499 1500 def eval_ast_result(self, ast_node: ast.AST) -> (bool, bool): 1501 """ 1502 Eval ast_node and get result, only used in control flow node. 1503 """ 1504 # ast.Constant can be check without eval 1505 if isinstance(ast_node, ast.Constant): 1506 return True, bool(ast.value) 1507 # Get the module where the code of ast_node is located 1508 file_path = inspect.getfile(type(self.get_origin_network())) 1509 module = None 1510 for m in list(sys.modules.values()): 1511 if hasattr(m, "__file__") and m.__file__ and os.path.normcase(m.__file__) == os.path.normcase(file_path): 1512 module = m 1513 break 1514 if not module: 1515 logger.warning("Failed to get module of ast_node.") 1516 return False, False 1517 # eval ast_node and get result 1518 logger.debug(f"Eval ast node: {astunparse.unparse(ast_node)}") 1519 ast_expr = ast.Expression(ast_node) 1520 ast_expr = ast.fix_missing_locations(ast_expr) 1521 try: 1522 # eval with ast make this operation free of instruction injection 1523 # pylint: disable=eval-used 1524 result = eval(compile(ast_expr, "eval_ast_result", "eval"), {**globals(), **module.__dict__}, locals()) 1525 except Exception as e: # pylint: disable=broad-except 1526 logger.debug(f"Cannot get result of ast_node by eval, err:{e}") 1527 return False, False 1528 logger.debug(f"Eval ast result success, result: {result}") 1529 return True, bool(result) 1530 1531 def flatten_static_if_control_flow(self): 1532 """ 1533 For static if control flow, flatten codes in branch which will be executed and erase another branch. 1534 """ 1535 for node in self.all_nodes()[:]: 1536 if not node.get_belong_symbol_tree(): 1537 # the node has been erased 1538 continue 1539 if isinstance(node, ControlFlow) and node.test_result is not None: 1540 stree = node.get_belong_symbol_tree() 1541 if node.test_result: 1542 stree.flatten_nodes(node.body_node, True, True) 1543 else: 1544 if node.orelse_node is not None: 1545 stree.flatten_nodes(node.orelse_node, True, True) 1546 else: 1547 stree.erase_node(node.body_node) 1548 1549 def add_custom_codes(self, code: str): 1550 """Add user custom codes""" 1551 code_ast = ast.parse(code) 1552 self._custom_codes.extend(code_ast.body) 1553 1554 def get_custom_codes(self) -> List[ast.AST]: 1555 """Add user custom codes""" 1556 return self._custom_codes 1557 1558 def save_file_path_to_sys(self, level_num, file_path, belonging_ast: ast.AST = None): 1559 """ 1560 Save file path into stree._import_asts. `level_num` is used when level exist in ast.ImportFrom. 1561 1562 When level_num = 0(e.g. from xxx import yyy), current path will be saved. 1563 When level_num = 1(e.g. from .xxx import yyy), current path will be saved. 1564 When level_num = 2(e.g. from ..xxx import yyy), the path one level above the current path will be saved. 1565 """ 1566 file_path = os.path.dirname(os.path.abspath(file_path)) 1567 file_path = os.path.normcase(file_path) 1568 file_path = os.path.normpath(file_path) 1569 if level_num > 1: 1570 for _ in range(level_num - 1): 1571 file_path = os.path.dirname(file_path) 1572 sys_path_append_ast = ast.parse(f"sys.path.insert(0, r'{file_path}')").body[0] 1573 # add imports to import_asts of belonging_ast 1574 import_asts = self._get_imports_list_of_ast(belonging_ast) 1575 import_asts.append(ast.Import([ast.alias(name='sys', asname=None)])) 1576 import_asts.append(sys_path_append_ast) 1577 1578 def save_imports_from_file(self, file_path, belonging_ast: ast.AST = None): 1579 """Save imports from file""" 1580 self.save_file_path_to_sys(0, file_path, belonging_ast) 1581 if not os.path.exists(file_path): 1582 raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.") 1583 with open(file_path, "r", encoding="utf-8") as f: 1584 source_code = f.read() 1585 import_nodes = AstImportFinder(ast.parse(dedent(source_code))).get_import_node() 1586 if not import_nodes: 1587 return 1588 # add imports to import_asts of belonging_ast 1589 import_asts = self._get_imports_list_of_ast(belonging_ast) 1590 for import_node in import_nodes: 1591 import_node = SymbolTree._process_relative_import(import_node, file_path) 1592 if import_node: 1593 import_asts.append(import_node) 1594 1595 def add_import(self, module: types.ModuleType, name: str, belonging_ast: None): 1596 """add codes: from `module` import `name`""" 1597 if not isinstance(module, types.ModuleType): 1598 raise TypeError(f"For add_import, module should be ModuleType, but got {type(module)}") 1599 if not hasattr(module, name): 1600 logger.info(f"module {module.__name__} doesn't have attr '{name}', it may be a local variable.") 1601 return 1602 # add imports to import_asts of belonging_ast 1603 import_asts = self._get_imports_list_of_ast(belonging_ast) 1604 if module.__name__ == "__main__": 1605 # get attr from module instead of import to avoid duplicate execution of __main__ module 1606 code = f"{name} = getattr(sys.modules['__main__'], '{name}')" 1607 code_ast = ast.parse(code).body[0] 1608 import_asts.append(code_ast) 1609 elif module.__name__ == "builtins": 1610 # built-in functions are not need to be imported 1611 pass 1612 else: 1613 # add import of obj to ast 1614 func_file_path = inspect.getabsfile(module) 1615 func_file_path = os.path.normcase(func_file_path) 1616 prefix_paths = [] 1617 for path in sys.path: 1618 path = os.path.normcase(path) 1619 if func_file_path.startswith(path): 1620 prefix_paths.append(path) 1621 prefix_paths.sort(key=len, reverse=True) 1622 for path in prefix_paths: 1623 import_path = func_file_path[len(path):] 1624 import_str = import_path.replace(os.path.sep, '.') 1625 import_str = import_str[1:] # remove first '.' 1626 mod = import_str.rsplit('.', 1)[0] 1627 if SymbolTree._check_import(func_file_path[:len(path)], mod): 1628 import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0) 1629 import_asts.append(import_node) 1630 break 1631 else: 1632 self.save_file_path_to_sys(0, func_file_path, belonging_ast) 1633 mod = os.path.basename(func_file_path).rsplit('.')[0] 1634 import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0) 1635 import_asts.append(import_node) 1636 1637 def _get_imports_list_of_ast(self, belonging_ast: ast.AST): 1638 # get import_asts of belonging_ast 1639 import_asts = self._import_asts 1640 if belonging_ast is not None: 1641 if belonging_ast in self._father_class_ast: 1642 import_asts = self._father_class_ast.get(belonging_ast) 1643 elif belonging_ast in self._external_ast: 1644 import_asts = self._external_ast.get(belonging_ast) 1645 return import_asts 1646 1647 def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: 1648 if isinstance(node_or_name, str): 1649 return self.get_node(node_or_name) 1650 return node_or_name 1651 1652 def _handle_custom_obj_in_normalized_args(self, node: Node): 1653 """ 1654 Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj. 1655 1656 Args: 1657 node (Node): A Node whose arguments and keyword arguments to be handled. 1658 """ 1659 normalized_args: {str, ScopedValue} = {} 1660 for key, value in node.get_normalized_args().items(): 1661 if not isinstance(value, ScopedValue): 1662 raise TypeError("value should be ScopedValue, got: ", type(value)) 1663 if value.type == ValueType.CustomObjValue: 1664 # Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue 1665 arg_name = self.unique_name(f"arg_{type(value.value).__name__}") 1666 setattr(self._origin_network, arg_name, value.value) 1667 # Add new code to __init__(): self.arg_name = obj.arg_name 1668 new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0] 1669 self._init_func_ast.body.append(new_ast) 1670 # Modify node's normalized_args: CustomObjValue -> self.arg_name 1671 normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self") 1672 else: 1673 normalized_args[key] = value 1674 node.set_normalized_args(normalized_args) 1675 1676 def _get_cls_through_file(self): 1677 """ 1678 Load rewritten network class of current SymbolTree. 1679 1. Get source code of current SymbolTree. 1680 2. Saving source code to a tempfile. 1681 3. Import rewritten network class using "__import__" function. 1682 1683 Returns: 1684 A class handle. 1685 """ 1686 file_path = os.getcwd() 1687 file_path = os.path.join(file_path, "rewritten_network") 1688 if not os.path.exists(file_path): 1689 try: 1690 os.mkdir(file_path, mode=0o700) 1691 except FileExistsError: 1692 pass 1693 file_name = f"{self._opt_cls_name}_{id(self)}.py" 1694 network_file = os.path.join(file_path, file_name) 1695 with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f: 1696 source = self.get_code() 1697 f.write(source.encode('utf-8')) 1698 f.flush() 1699 os.fsync(f) 1700 tmp_module_path, tmp_module_file = os.path.split(network_file) 1701 tmp_module_name = tmp_module_file[:-3] 1702 sys.path.append(tmp_module_path) 1703 tmp_module = None 1704 1705 i = 0 1706 while not tmp_module: 1707 spec = importlib.util.spec_from_file_location(tmp_module_name, network_file) 1708 if spec: 1709 tmp_module = importlib.util.module_from_spec(spec) 1710 spec.loader.exec_module(tmp_module) 1711 else: 1712 logger.warning(f"load module {tmp_module_name} failed, retrying.") 1713 if i > 10: 1714 break 1715 time.sleep(0.5) 1716 i += 1 1717 if not tmp_module: 1718 raise ImportError(f"load module {tmp_module_name} failed.") 1719 # Save new module to sys.modules to support inspect.getsource(). 1720 sys.modules[tmp_module_name] = tmp_module 1721 network_cls = getattr(tmp_module, self._opt_cls_name) 1722 if network_cls is None: 1723 raise RuntimeError("Can not find network class:", self._opt_cls_name) 1724 return network_cls 1725 1726 def _on_change(self, event: Event): 1727 self._modified = True 1728 self.changed(event) 1729 1730 def _cal_difference_set(self, input, other): 1731 """Calculate different set of two sets.""" 1732 set1 = set(input) 1733 set2 = set(other) 1734 return set1 - set2 1735 1736 def _merge_origin_property(self, new_net): 1737 """Merge property of two network.""" 1738 tmp = self._cal_difference_set(dir(self._origin_network), dir(new_net)) 1739 new_attr_names = self._cal_difference_set(tmp, self._deleted_field.keys()) 1740 for name in new_attr_names: 1741 setattr(new_net, name, getattr(self._origin_network, name)) 1742 # merger cells 1743 cells = self._cal_difference_set(self._origin_network.name_cells().keys(), new_net.name_cells().keys()) 1744 cells = self._cal_difference_set(cells, self._deleted_node) 1745 for c in cells: 1746 new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c]) 1747 # merge primitives 1748 # pylint: disable=protected-access 1749 primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys()) 1750 for p in primitives: 1751 new_net._primitives[p] = self._origin_network._primitives[p] # pylint: disable=protected-access 1752 1753 def _process_duplicate_name_modules(self, module_ast: ast.Module): 1754 """Adjust names of imported modules with same name and different import path.""" 1755 # {name1: [path1, path2, ...], ...} 1756 name_path_dict: Dict[str, List[str]] = {} 1757 # names of modules need to be suffixed: {name1: suffixed_name1, ...} 1758 name_need_suffix: Dict[str, str] = {} 1759 # used to record replace actions in ast.ImportFrom 1760 import_replacer = AstReplacer(None) 1761 self._tmp_replacers.append(import_replacer) 1762 1763 def suffix_alias(alias: ast.alias, suffix: int): 1764 """suffix the name of alias in ast.ImportFrom""" 1765 new_name = f"{alias.asname}_{suffix}" if alias.asname else f"{alias.name}_{suffix}" 1766 import_replacer._trace.append((alias, 'asname', alias.asname, new_name)) # pylint: disable=protected-access 1767 alias.asname = new_name 1768 return new_name 1769 1770 def is_divider(ast_node): 1771 """judge if ast node is divider of new class or function by checking ast.Expr of '#'.""" 1772 return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#' 1773 1774 def record_imports(ast_node: ast.ImportFrom): 1775 """record name and path of imported modules to find the duplicate name modules.""" 1776 for alias in ast_node.names[:]: 1777 name = alias.asname if alias.asname else alias.name 1778 if name == '*': 1779 continue 1780 # current name is firstly imported, just record it 1781 if name not in name_path_dict: 1782 name_path_dict[name] = [ast_node.module] 1783 continue 1784 # current name is imported before, check whether it is a duplicated name 1785 for idx, path in enumerate(name_path_dict[name]): 1786 if path.startswith(ast_node.module): 1787 # e.g. origin code is 'from a.b.c import A' and new code is 'from a.b import A' 1788 # then we update name_path_dict[name][idx] from 'a.b.c' to 'a.b' and update name to A_{idx} 1789 name_path_dict[name][idx] = ast_node.module 1790 if idx > 0: 1791 name_need_suffix[name] = suffix_alias(alias, idx) 1792 break 1793 elif ast_node.module.startswith(path): 1794 # e.g. origin code is 'from a.b import A' and new code is 'from a.b.c import A' 1795 # then we just need to update name to A_{idx} 1796 if idx > 0: 1797 name_need_suffix[name] = suffix_alias(alias, idx) 1798 break 1799 else: 1800 # current name is imported from a new path, save the path and update the name 1801 name_path_dict[name].append(ast_node.module) 1802 name_need_suffix[name] = suffix_alias(alias, len(name_path_dict[name]) - 1) 1803 1804 def suffix_names_in_ast(ast_node: Union[ast.ClassDef, ast.FunctionDef]): 1805 """suffix names in ast.ClassDef or ast.FunctionDef""" 1806 if not name_need_suffix: 1807 return 1808 name_replacer = AstReplacer(ast_node) 1809 self._tmp_replacers.append(name_replacer) 1810 for name, new_name in name_need_suffix.items(): 1811 name_replacer.replace_all(name, new_name) 1812 1813 for ast_node in module_ast.body: 1814 if isinstance(ast_node, ast.ImportFrom): 1815 record_imports(ast_node) 1816 if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)): 1817 suffix_names_in_ast(ast_node) 1818 if is_divider(ast_node): 1819 name_need_suffix.clear() 1820