• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""SymbolTree class define of Rewrite according to forward function of a network."""
16import stat
17from typing import Optional, Union, Tuple, Any, Dict, List
18import types
19import os
20import sys
21import ast
22import importlib.util
23import time
24import inspect
25from textwrap import dedent
26from collections import OrderedDict
27
28from mindspore.nn import Cell
29from mindspore import log as logger
30from .symbol_tree_dumper import SymbolTreeDumper
31from ..node import Node, TreeNode, ControlFlow, CallFunction, NodeManager
32from ..api.node_type import NodeType
33from ..api.scoped_value import ScopedValue, ValueType
34from ..ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder, \
35    AstImportFinder
36from ..common.namer import TargetNamer, NodeNamer, ClassNamer
37from ..common.observer import Observer
38from ..common.observable import Observable
39from ..common.event import Event
40
41if sys.version_info >= (3, 9):
42    import ast as astunparse # pylint: disable=reimported, ungrouped-imports
43else:
44    import astunparse
45
class Position:
    """
    Position indicates a source code position in one network.

    Rewrite recommends creating positions through the class method `create()` rather than calling the constructor of
    Position directly.

    Args:
        symbol_tree (SymbolTree): A handler of the SymbolTree in which the position is located.
        node (Node): A handler of the Node around which the position is located.
        before_node (bool): Whether the position is before (True) or after (False) `node`.
    """

    def __init__(self, symbol_tree, node, before_node: bool):
        self.symbol_tree = symbol_tree
        self.node = node
        self.before_node = before_node

    @classmethod
    def create(cls, symbol_tree, node, before_node):
        """
        Class method of Position. Return None when `symbol_tree` or `node` is None.

        Args:
            symbol_tree: A handler of the SymbolTree in which the position is located.
            node: A handler of the Node around which the position is located.
            before_node (bool): Whether the position is before (True) or after (False) `node`.

        Returns:
            An instance of `cls` (Position or a subclass of it), or None when `symbol_tree` or `node` is None.
        """
        if symbol_tree is None or node is None:
            return None
        # Construct through `cls` so that subclasses of Position get instances of their own type.
        return cls(symbol_tree, node, before_node)
79
80
class FieldFinder(AstFinder):
    """
    Finder that reports whether `self.<field>` is accessed anywhere inside a given ast scope.

    Args:
        scope (ast.AST): The ast node whose subtree is searched.
    """

    def __init__(self, scope: ast.AST):
        super().__init__(scope)
        # outcome of the latest search and the field name searched for
        self._result = False
        self._field_name = ""

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Flag a hit when `node` is `self.<searched field>`, then continue into its children."""
        owner = node.value
        accessed_on_self = isinstance(owner, ast.Name) and owner.id == "self"
        if accessed_on_self and node.attr == self._field_name:
            self._result = True
        return super().generic_visit(node)

    def check(self, field) -> bool:
        """
        Search the scope for an access of `self.<field>`.

        Args:
            field (str): Name of the field to look for.

        Returns:
            True if `self.<field>` appears in the scope, otherwise False.
        """
        self._field_name = field
        self._result = False
        self.visit(self._scope)
        return self._result
119
120
121class SymbolTree(Observer, Observable, NodeManager):
122    """
123    A symbol-tree usually corresponding to forward method of a network.
124
125    Rewrite recommend using SymbolTreeBuilder to instantiate an instance of SymbolTree rather than invoking constructor
126    of SymbolTree directly.
127
128    Args:
129        origin_network (Cell): A handler to original network instance.
130        module_ast (ast.Module): An instance of ast.AST represents ast node of original network.
131    """
132    # whether parse CallFunction node inserted by user.
133    _unparse_inserted_function = True
134
135    def __init__(self, origin_network: Cell, module_ast: ast.Module):
136        Observer.__init__(self)
137        Observable.__init__(self)
138        self._node_namer = NodeNamer()
139        self._node_namer.add_name('obj')
140        NodeManager.__init__(self)
141        NodeManager.set_manager_node_namer(self, self._node_namer)
142        NodeManager.reg_observer(self, observer=self)
143        # init unique-namers
144        self._target_namer = TargetNamer()
145        # input arguments of function
146        self._ori_cls_name = type(origin_network).__name__
147        self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
148        NodeManager.set_manager_name(self, self._opt_cls_name)
149        self._origin_network = origin_network
150        self._module_ast: ast.Module = module_ast
151        self._import_asts: Optional[ast.Ast] = []
152        self._class_ast: Optional[ast.ClassDef] = None
153        self._root_ast: Optional[ast.FunctionDef] = None
154        self._init_func_ast: Optional[ast.FunctionDef] = None
155        self._deleted_field = {}
156        self._deleted_node = []
157        # {ast_function: [import_asts]}
158        self._external_ast: Dict[ast.FunctionDef, list] = OrderedDict()
159        # {ast_class: [import_asts]}
160        self._father_class_ast: Dict[ast.ClassDef, list] = OrderedDict()
161        self._modified = False
162        self._saved_file_name = "./network_define.py"
163        # used to insert "sys.path.append(xxx)"
164        self._net_file_paths = []
165        self._tmp_import_strs = []
166        self._tmp_unmodified_strees: {type, List[SymbolTree]} = {}
167        self._tmp_replacers = []
168        # user custom codes
169        self._custom_codes: List[ast.AST] = []
170        # local primitive instances initialized during forward method, e.g. abs_inst = P.Abs()
171        self._local_prim_inits: List[Node] = []
172
173    @staticmethod
174    def _remove_unused_import(module_ast):
175        """remove unused import in self._module_ast"""
176        import_nodes: List[Union[ast.Import, ast.ImportFrom]] = []
177
178        def is_divider(ast_node):
179            """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
180            return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'
181
182        for ast_node in module_ast.body[:]:
183            if isinstance(ast_node, (ast.Import, ast.ImportFrom)):
184                import_nodes.append(ast_node)
185            if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
186                str_checker = StrChecker(ast_node)
187                for import_node in import_nodes:
188                    for alias in import_node.names[:]:
189                        name = alias.asname if alias.asname else alias.name
190                        if name == '*':
191                            continue
192                        if not str_checker.check(name):
193                            import_node.names.remove(alias)
194                    if not import_node.names:
195                        module_ast.body.remove(import_node)
196            if is_divider(ast_node):
197                import_nodes.clear()
198
199    @staticmethod
200    def _remove_duplicated_import(module_ast):
201        """Remove duplicated import of 'net'."""
202        imports = set()
203        futures = set()
204        names = set()
205
206        class TransImportNode(ast.NodeTransformer):
207            """Find all import nodes from input ast node."""
208
209            def visit_ClassDef(self, node: ast.ClassDef) -> Any:
210                if node.name not in names:
211                    names.add(node.name)
212                    return node
213                return None
214
215            def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
216                if node.name not in names:
217                    names.add(node.name)
218                    return node
219                return None
220
221            def visit_Try(self, node: ast.Try) -> Any:
222                if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
223                    import_str = astunparse.unparse(node)
224                    if import_str not in imports:
225                        imports.add(import_str)
226                        return node
227                return None
228
229            def visit_Import(self, node: ast.Import) -> Any:
230                import_str = astunparse.unparse(node)
231                if import_str not in imports:
232                    imports.add(import_str)
233                    return node
234                return None
235
236            def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
237                """
238                Once the father class 'A' is defined in the current module, all the next imported class 'A' should
239                be removed. e.g.
240                    def class A():
241                        ...
242                    from xxx import A, B
243                    =>
244                    def class A():
245                        ...
246                    from xxx import B
247                """
248                import_str = astunparse.unparse(node)
249
250                if import_str not in imports:
251                    imports.add(import_str)
252                    # remove "__future__" module
253                    if node.module == '__future__':
254                        futures.add(node.module)
255                        return None
256                    # remove modules which have been defined in the code file
257                    # it occurs when class A is a father class and other sub-classes import A
258                    for alias in node.names[:]:
259                        if alias.name in names:
260                            node.names.remove(alias)
261                    # if the alias(es) in node.names are all removed, this import statement should be removed
262                    if not node.names:
263                        return None
264                    return node
265                return None
266
267        get_node_handler = TransImportNode()
268        get_node_handler.generic_visit(module_ast)
269
270    @staticmethod
271    def _remove_arg_annotations(module_ast):
272        """Remove annotations in ast.arg to avoid 'xxx is not defined'."""
273        ast_args: List[ast.arg] = AstFinder(module_ast).find_all(ast.arg)
274        for ast_arg in ast_args:
275            ast_arg.annotation = None
276
277    @staticmethod
278    def _check_import(import_path: str, import_module: str):
279        """
280        Check whether import operation is valid when importing module from specific path.
281        """
282        if import_path not in sys.path:
283            sys.path.append(import_path)
284        try:
285            importlib.import_module(name=import_module)
286        except (ValueError, ImportError) as e:
287            logger.info(f"Test import {import_module} from {import_path} failed: {e}.")
288            return False
289        except Exception as e: # pylint: disable=W0703
290            logger.info(f"Test import {import_module} from {import_path} failed: {e}.")
291            return False
292        return True
293
294    @staticmethod
295    def _process_relative_import(import_node: Union[ast.Import, ast.ImportFrom], file_path: str):
296        """Process relative imports"""
297        file_path = os.path.normcase(file_path)
298        file_path = os.path.normpath(file_path)
299        if isinstance(import_node, ast.ImportFrom):
300            # pad the ImportFrom with parent path
301            # e.g. from ..C import xxx -> from A.B.C import xxx
302            import_module = SymbolTree._get_valid_import_info(import_node, file_path)
303            if import_module:
304                import_node = ast.ImportFrom(module=import_module, names=import_node.names, level=0)
305        return import_node
306
307    @staticmethod
308    def _get_valid_import_info(import_node: ast.ImportFrom, file_path: str):
309        """Get valid import info while import_node.module is at form of relative path"""
310        file_path = os.path.dirname(os.path.abspath(file_path))
311        # get real path from import_node.level
312        # from .(A) import xxx: current path
313        # from ..(A) import xxx: last level path
314        level = import_node.level
315        # from A import xxx: it does not need to pad, directly return the module name
316        if level == 0:
317            return import_node.module
318        if level > 1:
319            for _ in range(level - 1):
320                file_path = os.path.dirname(file_path)
321        file_path_tmp = file_path[:]
322        max_level_count = file_path.count(os.path.sep) - 1
323        level_count = 0
324        # suffix is the module_name, e.g. 'A' in 'from ..(A) import xxx'
325        suffix = ''
326        if import_node.module:
327            suffix = '.' + import_node.module
328        while level_count < max_level_count:
329            file_path_tmp = os.path.dirname(file_path_tmp)
330            if file_path_tmp not in sys.path:
331                logger.debug(f"{file_path_tmp} not in sys.path, try upper level.")
332                level_count += 1
333                continue
334            import_module = file_path[len(file_path_tmp) + 1:].replace(os.path.sep, '.') + suffix
335            if SymbolTree._check_import(file_path_tmp, import_module):
336                # try test code success
337                return import_module
338            # test import ast failed, try upper level
339            level_count += 1
340            logger.info(f"Try upper level.")
341        # try codes with all level failed
342        logger.info(f"Test import code: {astunparse.unparse(import_node).strip()} failed, ignore this import code.")
343        return None
344
345    @staticmethod
346    def insert_to_ast_while_insert_input(new_node: Node, node_manager: NodeManager):
347        """update ast when inserting NodeType.Input node"""
348        if not isinstance(node_manager, (SymbolTree, CallFunction)):
349            raise ValueError(f"Only support insert Input node into a SymbolTree or a node with type of "
350                             f"CallFunction, but get {type(node_manager)}")
351        # insert a new input
352        node_manager.get_input_nodes().append(new_node)
353        ast_function: ast.FunctionDef = node_manager.get_manager_ast()
354        arg: str = new_node.get_targets()[0].value
355        ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
356        AstModifier.append_arg_to_function(ast_function, ast_arg)
357
358    @staticmethod
359    def insert_to_ast_while_insert_cell_primitive(new_node: Node, base_node: Node, before_node: bool,
360                                                  node_manager: NodeManager, stree):
361        """update ast when inserting NodeType.CallCell or NodeType.CallPrimitive node"""
362        # create a new assign statement
363        ast_assign = new_node.get_ast()
364        if ast_assign is None:
365            func_name = stree.unique_func_name(new_node.get_name())
366            new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
367            ast_assign = new_node.update_ast_node()
368        if not isinstance(ast_assign, ast.Assign):
369            raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
370        # Save instance into _origin_network.
371        setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
372        # Insert ast to __init__ function
373        if isinstance(new_node, TreeNode):
374            init_code = f"{new_node.get_func_name()} = " \
375                        f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
376        else:
377            init_code = f"{new_node.get_func_name()} = obj.{new_node.get_name()}"
378        init_ast = ast.parse(init_code).body[0]
379        AstModifier.insert_ast_to_function(stree.get_init_func_ast(), init_ast)
380        # Insert ast to construct_function/class_internal_function
381        ast_base_node = base_node.get_ast() if base_node else None
382        ast_node_manager = node_manager.get_manager_ast()
383        if not ast_node_manager:
384            raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
385                               "when inserting the ast.")
386        AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node)
387
388    @staticmethod
389    def insert_to_ast_while_insert_function(new_node: CallFunction, base_node: Node, before_node: bool,
390                                            node_manager: NodeManager, stree: 'SymbolTree'):
391        """update ast when inserting NodeType.CallFunction node"""
392        func_name = str(new_node.get_func_name())
393        # create a new assign statement
394        ast_assign = new_node.get_ast()
395        if ast_assign is None:
396            ast_assign = new_node.update_ast_node()
397        # Insert ast to node_manager
398        ast_base_node = base_node.get_ast() if base_node else None
399        ast_node_manager = node_manager.get_manager_ast()
400        if not ast_node_manager:
401            raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
402                               "when inserting the ast.")
403        AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node)
404        # Ignore Python builtin functions
405        func_obj = new_node.get_instance()
406        if isinstance(func_obj, types.BuiltinFunctionType):
407            logger.warning(f"Ignore built in function: {func_name}")
408            return
409        # get ast.FunctionDef
410        source_code = inspect.getsource(func_obj)
411        ast_functiondef = ast.parse(dedent(source_code)).body[0]
412        if SymbolTree._unparse_inserted_function or not isinstance(ast_functiondef, ast.FunctionDef):
413            logger.debug(f"import '{func_name}' to access function object")
414            # add import to make sure that the function object can be accessed.
415            module = inspect.getmodule(func_obj)
416            top_node_manager = node_manager.get_top_manager()
417            belonging_ast = None if isinstance(top_node_manager, SymbolTree) else top_node_manager.get_manager_ast()
418            stree.add_import(module, func_name, belonging_ast)
419            return
420        # parse nodes in inserted function.
421        new_node.set_manager_ast(ast_functiondef)
422        new_node.set_manager_node_namer(stree.get_node_namer())
423        stree.get_external_ast()[ast_functiondef] = []
424        # import module which function defined in
425        func_file_path = inspect.getabsfile(func_obj)
426        stree.save_imports_from_file(func_file_path, ast_functiondef)
427        # expand ast codes in function
428        from ..ast_helpers import AstFlattener
429        ast_functiondef = AstFlattener().transform(ast_functiondef, [func_name], stree)
430        # parse ast codes into CallFunction Node
431        from ..parsers import ParserRegister
432        parser = ParserRegister.instance().get_parser(ast.FunctionDef)
433        parser.process(stree, ast_functiondef, node_manager=new_node)
434
435    @staticmethod
436    def insert_to_ast_while_insert_node(new_node: Node, base_node: Node, before_node: bool):
437        """ insert_to_ast_while_insert_node. """
438        stree = new_node.get_belong_symbol_tree()
439        if not stree:
440            raise ValueError(f"When inserting node to ast, the belonging symbol tree of new_node is None.")
441        node_manager = new_node.get_node_manager()
442        if not isinstance(node_manager, (SymbolTree, CallFunction, ControlFlow)):
443            raise ValueError(f"When inserting node to ast, the node_manager of new_node {new_node.get_name()} can "
444                             f"only be one of [SymbolTree, CallFunction, ControlFlow], but get {type(node_manager)}")
445        if new_node.get_node_type() == NodeType.Input:
446            SymbolTree.insert_to_ast_while_insert_input(new_node, node_manager)
447        elif new_node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
448            SymbolTree.insert_to_ast_while_insert_cell_primitive(new_node, base_node, before_node, node_manager,
449                                                                 stree)
450        elif new_node.get_node_type() == NodeType.CallFunction:
451            SymbolTree.insert_to_ast_while_insert_function(new_node, base_node, before_node, node_manager, stree)
452        else:
453            raise ValueError(f"When insert node '{new_node.get_name()}' into ast, the type of node can only be "
454                             f"one of [Input, CallCell, CallPrimitive, CallFunction, Tree], but got "
455                             f"{new_node.get_node_type()}.")
456
457    @staticmethod
458    def get_node_full_name(node: Node) -> str:
459        """Get full name of node"""
460        name = node.get_manager_name() if isinstance(node, NodeManager) else node.get_name()
461        # traverse node_manager with type of Node
462        node_manager = node.get_node_manager()
463        while isinstance(node_manager, Node):
464            name = f"{node_manager.get_manager_name()}.{name}"
465            node_manager = node_manager.get_node_manager()
466        # type of node_manager is SymbolTree now
467        name = f"{node_manager.get_manager_name()}.{name}"
468        return name
469
470    def local_prim_inits(self) -> List[Node]:
471        """get local primitives constructed during forward method"""
472        return self._local_prim_inits
473
474    def finish_build(self):
475        """Add Event.TopologicalChangeEvent event when build is finished."""
476        self.add_event(Event.TopologicalChangeEvent)
477
478    def get_ori_cls_name(self) -> str:
479        """
480        Get class name of original network.
481
482        Returns:
483            A str represents class name of original network.
484        """
485        return self._ori_cls_name
486
487    def get_opt_cls_name(self) -> str:
488        """
489        Get class name of rewritten network.
490
491        Returns:
492            A str represents class name of rewritten network.
493        """
494        return self._opt_cls_name
495
496    def get_module_ast(self):
497        """
498        Getter of `_module_ast`.
499
500        Returns:
501            An instance of ast.AST represents ast node of corresponding module.
502        """
503        return self._module_ast
504
505    def set_module_ast(self, ast_node: ast.Module):
506        """
507        Setter of _module_ast.
508
509        Args:
510            ast_node (ast.Module): An instance of ast.Module represents ast node of module of corresponding network
511                                   class.
512        """
513        self._module_ast = ast_node
514
515    def get_ast_root(self):
516        """
517        Getter of `_root_ast`.
518
519        Returns:
520            An instance of ast.AST represents ast node of corresponding forward method.
521        """
522        return self._root_ast
523
524    def set_ast_root(self, ast_node: ast.FunctionDef):
525        """
526        Setter of _root_ast.
527
528        Args:
529            ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of forward method of
530                                        corresponding network class.
531        """
532        self._root_ast = ast_node
533        NodeManager.set_manager_ast(self, ast_node)
534
535    def get_class_ast(self):
536        """
537        Getter of `_class_ast`.
538
539        Returns:
540            An instance of ast.ClassDef represents ast node of corresponding network class.
541        """
542        return self._class_ast
543
544    def set_class_ast(self, ast_node: ast.ClassDef):
545        """
546        Setter of `_class_ast`.
547
548        Args:
549            ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class.
550        """
551        self._class_ast = ast_node
552
553    def get_init_func_ast(self):
554        """
555        Getter of _init_func_ast.
556
557        Returns:
558            An instance of ast.FunctionDef represents ast node of init method of corresponding network class.
559        """
560        return self._init_func_ast
561
562    def set_init_func_ast(self, ast_node: ast.FunctionDef):
563        """
564        Setter of _init_func_ast.
565
566        Args:
567            ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of init method of
568                                        corresponding network class.
569        """
570        self._init_func_ast = ast_node
571
572    def get_origin_network(self):
573        """
574        Getter of `_origin_network`.
575
576        Returns:
577            An instance of Cell which represents original network.
578        """
579        return self._origin_network
580
581    def get_nodes_dict(self):
582        """Get dict of nodes"""
583        return self._nodes
584
585    def get_node_namer(self):
586        """Get _node_namer"""
587        return self._node_namer
588
589    def is_modified(self):
590        """
591        Check whether symbol tree is modified.
592
593        Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
594        the symbol tree is created.
595        """
596        return self._modified
597
598    def set_modified_true(self):
599        """
600        Set self._modified true.
601
602        Self._modified is set true when 'if' exists in the original network.
603        In this situation, different original network instance tends to be different.
604        Hence, the class name should be updated.
605        """
606        self._modified = True
607
608    def get_import_asts(self):
609        """Get _import_asts"""
610        return self._import_asts
611
612    def get_external_ast(self):
613        """Get _external_ast"""
614        return self._external_ast
615
616    def get_father_class_ast(self):
617        """Get _father_class_ast"""
618        return self._father_class_ast
619
620    def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
621        """
622        Getter of inputs in topological relation of current 'node_or_name'.
623
624        Args:
625            node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
626
627        Returns:
628            A list of instances of Node as input nodes if 'node_or_name' belong to current SymbolTree. An empty list if
629            'node_or_name' not belong to current SymbolTree.
630        """
631
632        real_node: Optional[Node] = self._get_real_node(node_or_name)
633        if real_node is None:
634            logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
635            return []
636        return node_or_name.get_inputs()
637
638    def get_node_users(self, node_or_name: Union[Node, str]) -> [Tuple[Node, int]]:
639        """
640        Getter of outputs in topological relation of current 'node_or_name'.
641
642        Args:
643            node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
644
645        Returns:
646            A list of instances of Node as output nodes if 'node_or_name' belong to current SymbolTree. An empty list if
647            'node_or_name' not belong to current SymbolTree.
648        """
649
650        real_node: Optional[Node] = self._get_real_node(node_or_name)
651        if real_node is None:
652            logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
653            return []
654        if real_node.get_node_type() == NodeType.Output:
655            return []
656        node_users = []
657        for target_users in real_node.get_target_users().values():
658            if not target_users:
659                continue
660            if target_users not in node_users:
661                node_users.extend(target_users)
662        return node_users
663
664    def before(self, node_or_name: Union[Node, str]) -> Position:
665        """
666        Get insert position before 'node_or_name' in source code list.
667        Consider using symbol_tree, node and before/after as position for sub-tree feature.
668
669        Note:
670            Topological order is not determined here which is determined by arguments of node and updated by
671            TopologicalManager automatically.
672
673        Args:
674            node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
675
676        Returns:
677            A Position represents an insert point.
678
679        Raises:
680            AssertError: If 'node_or_name' is not a Node or a str
681            RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
682                SymbolTree.
683        """
684
685        node = self._get_real_node(node_or_name)
686        if node is None:
687            raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
688        return Position.create(node.get_belong_symbol_tree(), node, True)
689
690    def after(self, node_or_name: Union[Node, str]) -> Position:
691        """
692        Get insert position after 'node_or_name' in source code list.
693        Consider using symbol_tree, node and before/after as position for sub-tree feature.
694
695        Note:
696            Topological order is not determined here which is determined by arguments of node and updated by
697            TopologicalManager automatically.
698
699        Args:
700            node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
701
702        Returns:
703            A Position represents an insert point.
704
705        Raises:
706            AssertError: If 'node_or_name' is not a Node or a str
707            RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
708                SymbolTree.
709        """
710        node = self._get_real_node(node_or_name)
711        if node is None:
712            raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
713        return Position.create(node.get_belong_symbol_tree(), node, False)
714
715    def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
716                    insert_to_ast: bool = True):
717        """
718        Insert a node before or after base_node.
719
720        Note:
721            Name of node will be unique while inserting node into SymbolTree.
722
723            ValueType.CustomObjValue type arguments will be converted to ValueType.NamingValue and custom object will
724            be saved in global_vars dict while inserting node into SymbolTree.
725
726            Targets of node will be unique while inserting node into SymbolTree.
727
728            A field instantiation statement will be added into "init" function of network class using node name as field
729            name when `insert_to_ast` is True while inserting node into SymbolTree.
730
731            An assign statement represents invoking to this node will be added into forward function of network class
732            corresponding to field-instantiation-statement when `insert_to_ast` is True while inserting node into
733            SymbolTree.
734
735            Topological relation is updated and inputs of corresponding node is updated.
736
737        Args:
738            new_node (Node): Node to be inserted.
739            base_node (Node): New node will be inserted before or after base_node.
740            before_node (bool): Indicate whether new node is inserted before base_node.
741            node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
742                NodeManager of symboltree's construct function.
743            insert_to_ast (bool): Indicate whether ast nodes need to be updated.
744
745        Returns:
746            An instance of node which has been inserted into SymbolTree.
747
748        Raises:
749            ValueError: Node in the SymbolTree is inserted into SymbolTree again.
750            RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
751        """
752        if new_node.get_belong_symbol_tree():
753            raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
754
755        # Check if base_node in current SymbolTree
756        if base_node is not None:
757            stree = base_node.get_belong_symbol_tree()
758            if stree is not None and stree is not self:
759                raise ValueError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
760                                 f"current: {self.get_ori_cls_name()}.")
761
762        # Check if node is inserted between Input node
763        if base_node is not None and base_node.get_node_type() == NodeType.Input:
764            valid = True
765            if before_node:
766                valid = False
767            if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
768                valid = False
769            if not valid:
770                raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
771
772        # save target name, which is used to provide unique target
773        if new_node.get_targets():
774            for target in new_node.get_targets():
775                self._target_namer.add_name(str(target))
776
777        self._handle_custom_obj_in_normalized_args(new_node)
778
779        # Insert node into NodeManager
780        if node_manager is None:
781            if base_node is None:
782                raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
783            node_manager = base_node.get_node_manager()
784
785        # set node's _belong_symbol_tree
786        new_node.set_belong_symbol_tree(self)
787
788        if node_manager is self:
789            NodeManager.insert_node(self, new_node, base_node, before_node)
790            if insert_to_ast:
791                # update init-function-ast and construct-function-ast
792                self.insert_to_ast_while_insert_node(new_node, base_node, before_node)
793        else:
794            node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
795
796        # register code changed event observer, which is used to update _modified flag.
797        if new_node.get_node_type() == NodeType.Tree:
798            new_node.symbol_tree.reg_observer(self)
799        elif isinstance(new_node, NodeManager):
800            new_node.reg_observer(self)
801
802        return new_node
803
804    def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
805        """
806        Append a node to SymbolTree.
807
808        Args:
809            node (Node): An instance of node to be appended.
810            append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
811                True.
812            node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
813                NodeManager of symboltree's construct function.
814
815        Returns:
816            An instance of node which has been appended to SymbolTree.
817        """
818        if node_manager is None:
819            node_manager = self
820        return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
821
822    def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
823        """
824        Append an original field node to SymbolTree. An original field node represents a node created from existing
825        statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
826        while these nodes appending to SymbolTree.
827        This method is called while building SymbolTree usually.
828
829        Args:
830            node (Node): An instance of node to be appended.
831            node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
832                NodeManager of symboltree's construct function.
833
834        Returns:
835            An instance of node which has been appended to SymbolTree.
836        """
837        return self.append_node(node, node_manager, False)
838
839    def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
840                          node_manager: NodeManager = None):
841        """
842        Append an input node to SymbolTree corresponding to parameter of forward method of network class.
843        This method is called while building SymbolTree usually.
844
845        Args:
846            ast_node (ast.AST): A ast Node corresponding to current parameter.
847            param_name (str): A str represents name of parameter of forward method of network class.
848            default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
849                means parameter has no default value.
850            node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
851                NodeManager of symboltree's construct function.
852
853        Returns:
854            An instance of input node which has been appended to SymbolTree.
855        """
856        if param_name == "self":
857            return
858        # check param_name duplicated
859        if node_manager is None:
860            node_manager = self
861        for input_node in node_manager.get_input_nodes():
862            targets = input_node.get_targets()
863            if len(targets) != 1:
864                raise RuntimeError("targets should have 1 elements")
865            target: ScopedValue = targets[0]
866            if target.type != ValueType.NamingValue:
867                raise RuntimeError("target.type should equal to ValueType.NamingValue")
868            if target.scope != "":
869                raise RuntimeError("target.scope should be empty")
870            exist_param = target.value
871            if exist_param == param_name:
872                raise RuntimeError("input duplicated:", param_name)
873        input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
874        self.append_origin_field(input_node, node_manager)
875
876    def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
877                               node_manager: NodeManager = None) -> Optional[Node]:
878        """
879        Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
880        a list or a dict.
881        This method is called while building SymbolTree usually.
882
883        Args:
884            ast_scope (ast.AST): A ast node represents ast node of scope of node.
885            ast_node (ast.AST): A ast node represents ast node.
886            node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
887                NodeManager of symboltree's construct function.
888
889        Returns:
890            An instance of python node if a new node has been appended to SymbolTree else None.
891        """
892        if ast_node is None:
893            return None
894        if isinstance(ast_node, (list, dict)) and not ast_node:
895            return None
896        return self.append_python_node(ast_scope, ast_node, node_manager)
897
898    def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
899        """
900        Append a python node to SymbolTree.
901        This method is called while building SymbolTree usually.
902
903        Args:
904            ast_scope (ast.AST): A ast node represents ast node of scope of node.
905            ast_node (ast.AST): A ast node represents ast node.
906            node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
907                NodeManager of symboltree's construct function.
908
909        Returns:
910            An instance of python node which has been appended to SymbolTree.
911        """
912        logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
913        node_name = type(ast_node).__name__
914        node = Node.create_python_node(ast_node, node_name)
915        if node_manager is None or node_manager is self:
916            NodeManager.append_python_node(self, node)
917        else:
918            node_manager.append_python_node(node)
919        return node
920
921    def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
922                   node_manager: NodeManager = None) -> Node:
923        """
924        Update return value of return of forward method of network class.
925
926        Args:
927            return_value (str): A str represents new return value.
928            arg_index (int): A int indicates which value in return to be updated.
929            return_idx (int): A int indicates which return node to be updated. Default: 0.
930            node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
931                symboltree's construct function.
932
933        Returns:
934            An instance of node represents return node after updated.
935        """
936        node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
937        if not node_returns:
938            raise RuntimeError("Current node_manager has no output")
939        if return_idx >= len(node_returns):
940            raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
941        node_return = node_returns[return_idx]
942        self.set_node_arg(node_return, arg_index, return_value)
943        return node_return
944
945    def erase_node(self, node_or_name: Union[Node, str]) -> Node:
946        """
947        Erase a node from SymbolTree.
948
949        Topological relation will be updated.
950
951        Args:
952            node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
953
954        Returns:
955            An instance of node which has been erased from SymbolTree.
956
957        Raises:
958            RuntimeError: If 'node_or_name' is not in current SymbolTree.
959            RuntimeError: If erase corresponding ast node failed.
960        """
961
962        node = self._get_real_node(node_or_name)
963        if node is None:
964            raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
965        # erase node in NodeManager
966        node_manager = node.get_node_manager()
967
968        logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
969                     f"node_manager: {node_manager.get_manager_name()}, "
970                     f"code: {astunparse.unparse(node.get_ast()).strip()}, "
971                     f"node_name:{node.get_name()}")
972
973        if node_manager is self:
974            NodeManager.erase_node(self, node)
975            if isinstance(node, ControlFlow):
976                ret = AstModifier.earse_ast_of_control_flow(self._root_ast.body, node.get_ast(), node.is_orelse)
977            else:
978                ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
979            if not ret:
980                raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
981        else:
982            node_manager.erase_node(node)
983        node.set_belong_symbol_tree(None)
984        self._deleted_node.append(node.get_name())
985        return node
986
987    def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
988        """
989        Replace an old_node with a node list.
990
991        Args:
992            old_node (Node): Node to be replaced.
993            new_nodes (list[Node]): Node list to replace in.
994
995        Returns:
996            Last node in new_nodes list.
997
998        Raises:
999            RuntimeError: If 'old_node' is isolated.
1000            RuntimeError: If 'old_node' is not belong to current SymbolTree.
1001        """
1002        real_old_node = self._get_real_node(old_node)
1003        if real_old_node is None:
1004            raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
1005        # insert new_nodes into node_manager
1006        node_manager = real_old_node.get_node_manager()
1007        # insert new_nodes into NodeManager
1008        base_node = old_node
1009        for node in new_nodes:
1010            self.insert_node(node, base_node, False, node_manager, True)
1011            base_node = node
1012        self.erase_node(old_node)
1013        return new_nodes[-1]
1014
1015    def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
1016        """
1017        Set argument of 'node'.
1018
1019        Args:
1020            node (Union[Node, str]): Node to be modified. Can be a node or name of node.
1021            index (int): Indicate which input being modified.
1022            arg (Union[ScopedValue, str]): New argument to been set.
1023
1024        Raises:
1025            RuntimeError: If 'node' is not belong to current SymbolTree.
1026        """
1027
1028        real_node = self._get_real_node(node)
1029        if real_node is None:
1030            raise RuntimeError("Node is not belong to current SymbolTree: ", node)
1031
1032        new_arg, old_arg = node.set_arg(arg, index)
1033        node.get_node_manager().on_update_arg(node, index, old_arg, new_arg)
1034
1035    def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str],
1036                             out_idx: Optional[int] = None):
1037        """
1038        Set argument of 'dst_node' by another Node.
1039
1040        Args:
1041            dst_node (Node): Node to be modified. Can be a node or name of node.
1042            arg_idx (int): Indicate which input being modified.
1043            src_node (Node): Node as new input. Can be a node or name of node.
1044            out_idx ([int, optional]): Indicate which output of 'src_node' as new input of 'dst_node'. Default is None
1045                which means use first output of 'node_to_link' as new input.
1046
1047        Raises:
1048            RuntimeError: If 'dst_node' is not belong to current SymbolTree.
1049            RuntimeError: If 'src_node' is not belong to current SymbolTree.
1050            RuntimeError: If 'out_idx' is out of range.
1051            RuntimeError: If 'src_node' has multi-outputs while 'out_idx' is None or 'out_idx' is not offered.
1052        """
1053
1054        real_dst_node = self._get_real_node(dst_node)
1055        if real_dst_node is None:
1056            raise RuntimeError("dst_node is not belong to current SymbolTree: ", dst_node)
1057        real_src_node = self._get_real_node(src_node)
1058        if real_src_node is None:
1059            raise RuntimeError("src_node is not belong to current SymbolTree: ", src_node)
1060
1061        targets = real_src_node.get_targets()
1062        if out_idx is None:
1063            if len(targets) != 1:
1064                raise RuntimeError("node should has one output when out_idx is not provided")
1065            out_idx = 0
1066        if out_idx >= len(targets):
1067            raise RuntimeError("out_idx out of range: ", out_idx)
1068        new_arg = targets[out_idx]
1069        real_dst_node.set_arg(new_arg, arg_idx)
1070        real_dst_node.get_node_manager().on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx)
1071
1072    def unique_name(self, name: str):
1073        """Get a unique name in the symboltree"""
1074        return self._target_namer.get_name(name)
1075
1076    def unique_func_name(self, name: str):
1077        """Get a unique function name in the symboltree"""
1078        if not hasattr(self._origin_network, name):
1079            return name
1080        suffix = 1
1081        while hasattr(self._origin_network, f"{name}_{suffix}"):
1082            suffix += 1
1083        return f"{name}_{suffix}"
1084
1085    def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
1086        """
1087        Set target of `node` .
1088
1089        Args:
1090            node (Union[Node, str]): Node to be modified. Can be a node or name of node.
1091            index (int): Indicate which target being modified.
1092            arg (Union[ScopedValue, str]): New target to been set.
1093
1094        Raises:
1095            ValueError: If `node` is not belong to current SymbolTree.
1096            ValueError: If index of `node` 's target is greater than number of targets.
1097        """
1098
1099        real_node = self._get_real_node(node)
1100        if real_node is None:
1101            raise ValueError("Node is not belong to current SymbolTree: ", node)
1102        if isinstance(target, str):
1103            target = ScopedValue.create_naming_value(target)
1104        targets = node.get_targets()
1105        if index >= len(targets):
1106            raise ValueError(f"Index of node's target should be less than {len(targets)}, but got {index}")
1107        old_target = targets[index]
1108        targets[index] = target
1109        node.set_targets(targets)
1110        self._topo_mgr.on_update_target(node, index, old_target, target)
1111
1112    def all_nodes(self, subtree_nodes: bool = True):
1113        """
1114        Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
1115
1116        Args:
1117            subtree_nodes (bool): Whether include nodes in subtree. Default: True.
1118
1119        Returns:
1120            A list of nodes.
1121        """
1122        nodes = []
1123        node_managers = [self]
1124        while node_managers:
1125            node_manager = node_managers.pop()
1126            nodes.extend(node_manager.nodes())
1127            for node in node_manager.nodes():
1128                if isinstance(node, NodeManager):
1129                    node_managers.append(node)
1130        if subtree_nodes:
1131            for tree_node in self.get_tree_nodes():
1132                stree = tree_node.symbol_tree
1133                nodes.extend(stree.all_nodes())
1134        return nodes
1135
1136    def get_node_from_name(self, node_name: str):
1137        """
1138        Get node from all NodeManagers in current symbol tree by `node_name`.
1139
1140        Args:
1141            node_name (str): A str represents name of node as key of query.
1142
1143        Returns:
1144            An instance of Node if found else None.
1145        """
1146        node_managers = [self]
1147        while node_managers:
1148            node_manager = node_managers.pop()
1149            node = node_manager.get_node(node_name)
1150            if node:
1151                return node
1152            for node in node_manager.nodes():
1153                if isinstance(node, NodeManager):
1154                    node_managers.append(node)
1155        return None
1156
1157    def get_node_tabulate(self, all_nodes: bool = False) -> str:
1158        """
1159        Get nodes information and nodes' topological relations.
1160
1161        Args:
1162            all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
1163                nodes, CellContainer nodes and sub symbol trees.
1164
1165        Returns:
1166            String of nodes' information and topological relations.
1167        """
1168        try:
1169            from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
1170        except ImportError:
1171            logger.warning("print_node_tabulate relies on the library `tabulate`, "
1172                           "which could not be found on this machine. Run `pip "
1173                           "install tabulate` to install the library.")
1174            return ""
1175        dump_str = NodeManager.dump(self, self.get_manager_name())
1176        if all_nodes:
1177            node_managers = [self]
1178            while node_managers:
1179                node_manager = node_managers.pop()
1180                for node in node_manager.nodes():
1181                    if isinstance(node, NodeManager):
1182                        dump_str += node.dump(SymbolTree.get_node_full_name(node))
1183                        node_managers.append(node)
1184            for tree_node in self.get_tree_nodes():
1185                stree = tree_node.symbol_tree
1186                dump_str += stree.get_node_tabulate(all_nodes)
1187        return dump_str
1188
1189    def dump(self):
1190        """Dump graph."""
1191        dump_st = SymbolTreeDumper(self)
1192        dump_st.dump()
1193
1194    def check_body_exist(self, body, code_bodies):
1195        """Check whether body already exist in code_bodies"""
1196        # Check import ast node exist by saving import code string to self._tmp_import_strs
1197        if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
1198            import_str = astunparse.unparse(body)
1199            if import_str in self._tmp_import_strs:
1200                return True
1201            self._tmp_import_strs.append(import_str)
1202            return False
1203
1204        # Check ClassDef ast node exist by using AstClassFinder
1205        if isinstance(body, ast.ClassDef):
1206            if sys.version_info >= (3, 9):
1207                class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
1208            else:
1209                class_finder = AstClassFinder(ast.Module(body=code_bodies))
1210            results = class_finder.find_all(body.name)
1211            return bool(results)
1212
1213        # Check FunctionDef ast node exist by using AstFunctionFinder
1214        if isinstance(body, ast.FunctionDef):
1215            if sys.version_info >= (3, 9):
1216                function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[]))
1217            else:
1218                function_finder = AstFunctionFinder(ast.Module(body=code_bodies))
1219            results = function_finder.find_all(body.name)
1220            return bool(results)
1221
1222        return False
1223
    def deduplicate_unmodified_stree(self, code_bodies):
        """
        Deduplicate unmodified symbol trees whose init functions have become identical.

        Init function may be different even if stree is not modified manually, when subnets in stree is
        initialized by different arguments.
        In this case, we need to wait for code_bodies being fully generated, so that the name of subnets
        will be updated, then we can deduplicate again according to ast of init function.

        For every group in self._tmp_unmodified_strees, a stree whose unparsed init function equals that of an
        earlier stree in the same group has its class removed from code_bodies, and references to its class name
        are replaced by the class name of the earlier stree. The method recurses until a pass removes nothing.

        Args:
            code_bodies (list): Ast nodes of the generated code. Modified in place.
        """
        # prepare AstClassFinder and AstReplacer
        # Both wrap the very same 'code_bodies' list object, so removals below are seen by the finder/replacer.
        if sys.version_info >= (3, 9):
            class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
            name_replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
        else:
            class_finder = AstClassFinder(ast.Module(body=code_bodies))
            name_replacer = AstReplacer(ast.Module(body=code_bodies))
        # deduplicate all unmodified strees in self._tmp_unmodified_strees
        deduplicated = False
        for _, unmodified_strees in self._tmp_unmodified_strees.items():
            if len(unmodified_strees) <= 1:
                continue
            init_func_codes = [astunparse.unparse(stree.get_init_func_ast()) for stree in unmodified_strees]
            # If the index of an element is not its own, it means that it is a duplicate element
            to_be_erase = []
            for idx, code in enumerate(init_func_codes):
                first_idx = init_func_codes.index(code)
                if first_idx != idx:
                    first_stree_cls_name = unmodified_strees[first_idx].get_opt_cls_name()
                    duplicated_stree_cls_name = unmodified_strees[idx].get_opt_cls_name()
                    logger.debug(f"replace stree:{duplicated_stree_cls_name} to {first_stree_cls_name}.")
                    # delete duplicated class from code_bodies
                    results = class_finder.find_all(duplicated_stree_cls_name)
                    for ast_cls in results:
                        code_bodies.remove(ast_cls)
                    # replace name of duplicated class in code_bodies to first_stree_cls_name
                    name_replacer.replace_all(duplicated_stree_cls_name, first_stree_cls_name)
                    # record deduplicated stree
                    to_be_erase.append(idx)
                    deduplicated = True
            # remove class in self._tmp_unmodified_strees
            # Pop in reverse so the remaining recorded indexes stay valid while the list shrinks.
            for idx in reversed(to_be_erase):
                unmodified_strees.pop(idx)

        # the name of subnets is updated, so we need to deduplicate again.
        # Recursion stops once a pass finds no duplicate; every productive pass pops at least one stree.
        if deduplicated:
            self._tmp_replacers.append(name_replacer)
            self.deduplicate_unmodified_stree(code_bodies)
1269
1270    def update_unmodified_stree(self, stree, code_bodies) -> bool:
1271        """
1272        For the unmodified symbol tree, only one definition code remains in the generated code.
1273        Everywhere else calling this symbol tree will use the class in this definition code.
1274        """
1275        # all modified ast.ClassDef will be exported to code
1276        if stree.is_modified():
1277            logger.debug(f"stree:{stree.get_opt_cls_name()} is modified.")
1278            return False
1279        # all un-modified ast.ClassDef only keep one instance
1280        unmodified_strees = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
1281        if not unmodified_strees:
1282            self._tmp_unmodified_strees[type(stree.get_origin_network())] = [stree]
1283            logger.debug(f"stree:{stree.get_opt_cls_name()} is the first stree.")
1284            return False
1285        # Init function may be different even if stree is not modified, when subnets in stree is
1286        # initialized by different arguments.
1287        first_stree = unmodified_strees[0]
1288        first_stree_cls_name = first_stree.get_opt_cls_name()
1289        if astunparse.unparse(stree.get_init_func_ast()) != astunparse.unparse(first_stree.get_init_func_ast()):
1290            # init ast may be updated after inserting subtrees of stree, so we need to save unmodified strees
1291            # and deduplicate later
1292            self._tmp_unmodified_strees[type(stree.get_origin_network())].append(stree)
1293            logger.debug(f"init func different, stree:{stree.get_opt_cls_name()}, first_stree:{first_stree_cls_name}.")
1294            return False
1295        # Un-modified ast.ClassDef already exist in code_bodies,
1296        # replace class name to class name of first un-modified ast.ClassDef.
1297        if sys.version_info >= (3, 9):
1298            replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
1299        else:
1300            replacer = AstReplacer(ast.Module(body=code_bodies))
1301        logger.debug(f"replace stree:{stree.get_opt_cls_name()} to {first_stree_cls_name}.")
1302        replacer.replace_all(stree.get_class_ast().name, first_stree_cls_name)
1303        self._tmp_replacers.append(replacer)
1304        return True
1305
1306    def init_code_bodies(self, code_bodies: list) -> int:
1307        """Init code bodied"""
1308        # Add basic imports
1309        code_bodies.append(ast.Import([ast.alias(name='sys', asname=None)]))
1310        code_bodies.append(ast.Import([ast.alias(name='mindspore', asname=None)]))
1311        code_bodies.append(ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0))
1312        code_bodies.append(ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)], level=0))
1313        code_bodies.append(ast.ImportFrom(module='mindspore.ops',
1314                                          names=[ast.alias(name='functional', asname='F')], level=0))
1315        code_bodies.append(ast.Expr(ast.Name("#", ast.Load())))
1316        # Add user custom codes into code_bodies
1317        custom_codes = self.get_custom_codes()
1318        for code_ast in custom_codes:
1319            code_bodies.append(code_ast)
1320        code_bodies.append(ast.Expr(ast.Name("#", ast.Load())))
1321        return len(code_bodies)
1322
1323    def convert_stree_to_code_bodies(self, stree: 'SymbolTree', code_bodies: list, dividing_pos=0) -> int:
1324        """
1325        Convert nodes in stree to code_bodies
1326        - Add external function asts into code_bodies
1327        - Add father class asts into code_bodies
1328        - Add import asts of symbol tree into code_bodies
1329        - Add user custom codes into code_bodies
1330        - Add class asts of symbol tree into code_bodies
1331        - Add subtrees to code_bodies
1332        """
1333        insert_pos = dividing_pos
1334        # Add external asts into code_bodies
1335        for ast_func, import_asts in reversed(stree.get_external_ast().items()):
1336            if self.check_body_exist(ast_func, code_bodies):
1337                continue
1338            # add imports of external_ast
1339            self._tmp_import_strs.clear()
1340            for ast_import in import_asts:
1341                if not self.check_body_exist(ast_import, code_bodies):
1342                    code_bodies.insert(insert_pos, ast_import)
1343                    insert_pos += 1
1344            # add external_ast
1345            code_bodies.insert(insert_pos, ast_func)
1346            insert_pos += 1
1347            # add divide
1348            code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
1349            insert_pos += 1
1350
1351        # Add father class asts into code_bodies
1352        for ast_class, import_asts in stree.get_father_class_ast().items():
1353            if self.check_body_exist(ast_class, code_bodies):
1354                continue
1355            # add imports of father class
1356            self._tmp_import_strs.clear()
1357            for ast_import in import_asts:
1358                if not self.check_body_exist(ast_import, code_bodies):
1359                    code_bodies.insert(insert_pos, ast_import)
1360                    insert_pos += 1
1361            # add ast of father class
1362            code_bodies.insert(insert_pos, ast_class)
1363            insert_pos += 1
1364            # add divide
1365            code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
1366            insert_pos += 1
1367
1368        # external functions and father class are above the dividing_pos to support deduplication.
1369        dividing_pos = insert_pos
1370
1371        # Add import asts of symbol tree into code_bodies
1372        self._tmp_import_strs.clear()
1373        for body in stree.get_import_asts():
1374            if not self.check_body_exist(body, code_bodies):
1375                code_bodies.insert(insert_pos, body)
1376                insert_pos += 1
1377
1378        # Add class asts of symbol tree into code_bodies
1379        if stree.get_module_ast():
1380            for body in stree.get_module_ast().body:
1381                if self.check_body_exist(body, code_bodies):
1382                    continue
1383                code_bodies.insert(insert_pos, body)
1384                insert_pos += 1
1385
1386        # add divide
1387        code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
1388        insert_pos += 1
1389
1390        # Add subtrees to code_bodies
1391        for node in stree.get_tree_nodes():
1392            sub_stree = node.symbol_tree
1393            # For the unmodified class, update class name to name of first class
1394            if self.update_unmodified_stree(sub_stree, code_bodies):
1395                continue
1396            dividing_pos = self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, dividing_pos)
1397
1398        # return new dividing position
1399        return dividing_pos
1400
1401    def get_code(self) -> str:
1402        """
1403        Get source code of modified network.
1404
1405        Returns:
1406            A str represents source code of modified network.
1407        """
1408        self._tmp_import_strs.clear()
1409        self._tmp_unmodified_strees.clear()
1410        self._tmp_replacers.clear()
1411        code_bodies = []
1412        begin_pos = self.init_code_bodies(code_bodies)
1413        self.convert_stree_to_code_bodies(self, code_bodies, begin_pos)
1414        self.deduplicate_unmodified_stree(code_bodies)
1415        if sys.version_info >= (3, 9):
1416            gencode_module = ast.Module(body=code_bodies, type_ignores=[])
1417        else:
1418            gencode_module = ast.Module(body=code_bodies)
1419        SymbolTree._remove_unused_import(gencode_module)
1420        self._process_duplicate_name_modules(gencode_module)
1421        SymbolTree._remove_duplicated_import(gencode_module)
1422        SymbolTree._remove_arg_annotations(gencode_module)
1423        ast.fix_missing_locations(self._module_ast)
1424        code = astunparse.unparse(gencode_module)
1425        # Revert the class name to its original state
1426        for replacer in self._tmp_replacers:
1427            replacer.undo_all()
1428        return code
1429
1430    def get_network(self):
1431        """
1432        Get modified network.
1433
1434        Returns:
1435            A network object.
1436        """
1437        cls = self._get_cls_through_file()
1438        new_net = cls(self._origin_network)
1439        self._merge_origin_property(new_net)
1440        # update parameters' names to fix duplicated names bug
1441        # which occurs after inserting cell to celllist/sequentialcell
1442        new_net.update_parameters_name()
1443        return new_net
1444
1445    def set_saved_file_name(self, file_name: str):
1446        if file_name.endswith(".py"):
1447            self._saved_file_name = file_name
1448        else:
1449            self._saved_file_name = file_name + ".py"
1450
1451    def get_saved_file_name(self):
1452        return self._saved_file_name
1453
1454    def save_network_to_file(self):
1455        abs_path = os.path.abspath(self._saved_file_name)
1456        if os.path.isfile(abs_path):
1457            os.remove(abs_path)
1458        with os.fdopen(os.open(self._saved_file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
1459            source = self.get_code()
1460            f.write(source.encode('utf-8'))
1461            f.flush()
1462
1463
    def flatten_nodes(self, node, erase_another_branch: bool = False, erase_nodes_after_return: bool = False):
        """
        Flatten nodes in ControlFlow node.

        Every node contained in `node` is erased from it and re-inserted, both as a Node and as an ast statement,
        into the node manager that contains `node`, next to `base_node`. `base_node` is advanced to each inserted
        node, so the next one is placed relative to the previous one. Afterwards `node` itself is erased.

        Args:
            node (ControlFlow): The branch (of an if statement) to be flattened.
            erase_another_branch (bool): If True, also erase the sibling branch: `node.body_node` when
                `node.is_orelse`, otherwise `node.orelse_node` if it is not None.
            erase_nodes_after_return (bool): If True, erase (with a warning) every node of the upper node manager
                that comes after the first node of type `NodeType.Output`.

        Raises:
            ValueError: If `node` is not a ControlFlow, or its node manager is not a SymbolTree, CallFunction or
                ControlFlow.
        """
        if not isinstance(node, ControlFlow):
            raise ValueError(f"For flatten_nodes, the type of node can only be ControlFlow, but got {type(node)}.")
        upper_node_manager = node.get_node_manager()
        # The statement list to edit: `.body` of the manager ast for SymbolTree/CallFunction, while a ControlFlow
        # manager's get_manager_ast() is used directly as the list.
        if isinstance(upper_node_manager, (SymbolTree, CallFunction)):
            ast_bodies = upper_node_manager.get_manager_ast().body
        elif isinstance(upper_node_manager, ControlFlow):
            ast_bodies = upper_node_manager.get_manager_ast()
        else:
            raise ValueError("For flatten_nodes, the node can only be contained in [SymbolTree, CallFunction, "
                             f"ControlFlow], but the node is in {type(upper_node_manager)}.")
        # Anchor on the else branch when there is one, otherwise on body_node; presumably this places the
        # flattened code after the whole if statement — TODO confirm.
        base_node = node.orelse_node if node.orelse_node else node.body_node
        # Iterate over a copy because nodes are erased from `node` inside the loop.
        for n in node.nodes()[:]:
            self.erase_node(n)
            # NOTE(review): the `False` flags presumably select "insert after base_node"; confirm against
            # insert_node / AstModifier.insert_ast_to_bodies.
            self.insert_node(n, base_node, False, upper_node_manager, False)
            AstModifier.insert_ast_to_bodies(ast_bodies, n.get_ast(), base_node.get_ast(), False)
            base_node = n
        self.erase_node(node)
        # remove another branch
        if erase_another_branch:
            if node.is_orelse:
                self.erase_node(node.body_node)
            elif node.orelse_node is not None:
                self.erase_node(node.orelse_node)
        # remove nodes after return node
        if erase_nodes_after_return:
            has_return = False
            # NOTE(review): nodes are erased while iterating; this assumes nodes() returns a snapshot list.
            for n in upper_node_manager.nodes():
                if has_return:
                    logger.warning(f"Node {n.get_name()} which is behind the flatten return node is "
                                   f"automatically erased.")
                    self.erase_node(n)
                elif n.get_node_type() == NodeType.Output:
                    has_return = True
1499
1500    def eval_ast_result(self, ast_node: ast.AST) -> (bool, bool):
1501        """
1502        Eval ast_node and get result, only used in control flow node.
1503        """
1504        # ast.Constant can be check without eval
1505        if isinstance(ast_node, ast.Constant):
1506            return True, bool(ast.value)
1507        # Get the module where the code of ast_node is located
1508        file_path = inspect.getfile(type(self.get_origin_network()))
1509        module = None
1510        for m in list(sys.modules.values()):
1511            if hasattr(m, "__file__") and m.__file__ and os.path.normcase(m.__file__) == os.path.normcase(file_path):
1512                module = m
1513                break
1514        if not module:
1515            logger.warning("Failed to get module of ast_node.")
1516            return False, False
1517        # eval ast_node and get result
1518        logger.debug(f"Eval ast node: {astunparse.unparse(ast_node)}")
1519        ast_expr = ast.Expression(ast_node)
1520        ast_expr = ast.fix_missing_locations(ast_expr)
1521        try:
1522            # eval with ast make this operation free of instruction injection
1523            # pylint: disable=eval-used
1524            result = eval(compile(ast_expr, "eval_ast_result", "eval"), {**globals(), **module.__dict__}, locals())
1525        except Exception as e: # pylint: disable=broad-except
1526            logger.debug(f"Cannot get result of ast_node by eval, err:{e}")
1527            return False, False
1528        logger.debug(f"Eval ast result success, result: {result}")
1529        return True, bool(result)
1530
1531    def flatten_static_if_control_flow(self):
1532        """
1533        For static if control flow, flatten codes in branch which will be executed and erase another branch.
1534        """
1535        for node in self.all_nodes()[:]:
1536            if not node.get_belong_symbol_tree():
1537                # the node has been erased
1538                continue
1539            if isinstance(node, ControlFlow) and node.test_result is not None:
1540                stree = node.get_belong_symbol_tree()
1541                if node.test_result:
1542                    stree.flatten_nodes(node.body_node, True, True)
1543                else:
1544                    if node.orelse_node is not None:
1545                        stree.flatten_nodes(node.orelse_node, True, True)
1546                    else:
1547                        stree.erase_node(node.body_node)
1548
1549    def add_custom_codes(self, code: str):
1550        """Add user custom codes"""
1551        code_ast = ast.parse(code)
1552        self._custom_codes.extend(code_ast.body)
1553
1554    def get_custom_codes(self) -> List[ast.AST]:
1555        """Add user custom codes"""
1556        return self._custom_codes
1557
1558    def save_file_path_to_sys(self, level_num, file_path, belonging_ast: ast.AST = None):
1559        """
1560        Save file path into stree._import_asts. `level_num` is used when level exist in ast.ImportFrom.
1561
1562        When level_num = 0(e.g. from xxx import yyy), current path will be saved.
1563        When level_num = 1(e.g. from .xxx import yyy), current path will be saved.
1564        When level_num = 2(e.g. from ..xxx import yyy), the path one level above the current path will be saved.
1565        """
1566        file_path = os.path.dirname(os.path.abspath(file_path))
1567        file_path = os.path.normcase(file_path)
1568        file_path = os.path.normpath(file_path)
1569        if level_num > 1:
1570            for _ in range(level_num - 1):
1571                file_path = os.path.dirname(file_path)
1572        sys_path_append_ast = ast.parse(f"sys.path.insert(0, r'{file_path}')").body[0]
1573        # add imports to import_asts of belonging_ast
1574        import_asts = self._get_imports_list_of_ast(belonging_ast)
1575        import_asts.append(ast.Import([ast.alias(name='sys', asname=None)]))
1576        import_asts.append(sys_path_append_ast)
1577
1578    def save_imports_from_file(self, file_path, belonging_ast: ast.AST = None):
1579        """Save imports from file"""
1580        self.save_file_path_to_sys(0, file_path, belonging_ast)
1581        if not os.path.exists(file_path):
1582            raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.")
1583        with open(file_path, "r", encoding="utf-8") as f:
1584            source_code = f.read()
1585            import_nodes = AstImportFinder(ast.parse(dedent(source_code))).get_import_node()
1586        if not import_nodes:
1587            return
1588        # add imports to import_asts of belonging_ast
1589        import_asts = self._get_imports_list_of_ast(belonging_ast)
1590        for import_node in import_nodes:
1591            import_node = SymbolTree._process_relative_import(import_node, file_path)
1592            if import_node:
1593                import_asts.append(import_node)
1594
1595    def add_import(self, module: types.ModuleType, name: str, belonging_ast: None):
1596        """add codes: from `module` import `name`"""
1597        if not isinstance(module, types.ModuleType):
1598            raise TypeError(f"For add_import, module should be ModuleType, but got {type(module)}")
1599        if not hasattr(module, name):
1600            logger.info(f"module {module.__name__} doesn't have attr '{name}', it may be a local variable.")
1601            return
1602        # add imports to import_asts of belonging_ast
1603        import_asts = self._get_imports_list_of_ast(belonging_ast)
1604        if module.__name__ == "__main__":
1605            # get attr from module instead of import to avoid duplicate execution of __main__ module
1606            code = f"{name} = getattr(sys.modules['__main__'], '{name}')"
1607            code_ast = ast.parse(code).body[0]
1608            import_asts.append(code_ast)
1609        elif module.__name__ == "builtins":
1610            # built-in functions are not need to be imported
1611            pass
1612        else:
1613            # add import of obj to ast
1614            func_file_path = inspect.getabsfile(module)
1615            func_file_path = os.path.normcase(func_file_path)
1616            prefix_paths = []
1617            for path in sys.path:
1618                path = os.path.normcase(path)
1619                if func_file_path.startswith(path):
1620                    prefix_paths.append(path)
1621            prefix_paths.sort(key=len, reverse=True)
1622            for path in prefix_paths:
1623                import_path = func_file_path[len(path):]
1624                import_str = import_path.replace(os.path.sep, '.')
1625                import_str = import_str[1:] # remove first '.'
1626                mod = import_str.rsplit('.', 1)[0]
1627                if SymbolTree._check_import(func_file_path[:len(path)], mod):
1628                    import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
1629                    import_asts.append(import_node)
1630                    break
1631            else:
1632                self.save_file_path_to_sys(0, func_file_path, belonging_ast)
1633                mod = os.path.basename(func_file_path).rsplit('.')[0]
1634                import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
1635                import_asts.append(import_node)
1636
1637    def _get_imports_list_of_ast(self, belonging_ast: ast.AST):
1638        # get import_asts of belonging_ast
1639        import_asts = self._import_asts
1640        if belonging_ast is not None:
1641            if belonging_ast in self._father_class_ast:
1642                import_asts = self._father_class_ast.get(belonging_ast)
1643            elif belonging_ast in self._external_ast:
1644                import_asts = self._external_ast.get(belonging_ast)
1645        return import_asts
1646
1647    def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
1648        if isinstance(node_or_name, str):
1649            return self.get_node(node_or_name)
1650        return node_or_name
1651
1652    def _handle_custom_obj_in_normalized_args(self, node: Node):
1653        """
1654        Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.
1655
1656        Args:
1657            node (Node): A Node whose arguments and keyword arguments to be handled.
1658        """
1659        normalized_args: {str, ScopedValue} = {}
1660        for key, value in node.get_normalized_args().items():
1661            if not isinstance(value, ScopedValue):
1662                raise TypeError("value should be ScopedValue, got: ", type(value))
1663            if value.type == ValueType.CustomObjValue:
1664                # Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
1665                arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
1666                setattr(self._origin_network, arg_name, value.value)
1667                # Add new code to __init__(): self.arg_name = obj.arg_name
1668                new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
1669                self._init_func_ast.body.append(new_ast)
1670                # Modify node's normalized_args: CustomObjValue -> self.arg_name
1671                normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
1672            else:
1673                normalized_args[key] = value
1674        node.set_normalized_args(normalized_args)
1675
1676    def _get_cls_through_file(self):
1677        """
1678        Load rewritten network class of current SymbolTree.
1679        1. Get source code of current SymbolTree.
1680        2. Saving source code to a tempfile.
1681        3. Import rewritten network class using "__import__" function.
1682
1683        Returns:
1684            A class handle.
1685        """
1686        file_path = os.getcwd()
1687        file_path = os.path.join(file_path, "rewritten_network")
1688        if not os.path.exists(file_path):
1689            try:
1690                os.mkdir(file_path, mode=0o700)
1691            except FileExistsError:
1692                pass
1693        file_name = f"{self._opt_cls_name}_{id(self)}.py"
1694        network_file = os.path.join(file_path, file_name)
1695        with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
1696            source = self.get_code()
1697            f.write(source.encode('utf-8'))
1698            f.flush()
1699            os.fsync(f)
1700        tmp_module_path, tmp_module_file = os.path.split(network_file)
1701        tmp_module_name = tmp_module_file[:-3]
1702        sys.path.append(tmp_module_path)
1703        tmp_module = None
1704
1705        i = 0
1706        while not tmp_module:
1707            spec = importlib.util.spec_from_file_location(tmp_module_name, network_file)
1708            if spec:
1709                tmp_module = importlib.util.module_from_spec(spec)
1710                spec.loader.exec_module(tmp_module)
1711            else:
1712                logger.warning(f"load module {tmp_module_name} failed, retrying.")
1713                if i > 10:
1714                    break
1715                time.sleep(0.5)
1716                i += 1
1717        if not tmp_module:
1718            raise ImportError(f"load module {tmp_module_name} failed.")
1719        # Save new module to sys.modules to support inspect.getsource().
1720        sys.modules[tmp_module_name] = tmp_module
1721        network_cls = getattr(tmp_module, self._opt_cls_name)
1722        if network_cls is None:
1723            raise RuntimeError("Can not find network class:", self._opt_cls_name)
1724        return network_cls
1725
1726    def _on_change(self, event: Event):
1727        self._modified = True
1728        self.changed(event)
1729
1730    def _cal_difference_set(self, input, other):
1731        """Calculate different set of two sets."""
1732        set1 = set(input)
1733        set2 = set(other)
1734        return set1 - set2
1735
1736    def _merge_origin_property(self, new_net):
1737        """Merge property of two network."""
1738        tmp = self._cal_difference_set(dir(self._origin_network), dir(new_net))
1739        new_attr_names = self._cal_difference_set(tmp, self._deleted_field.keys())
1740        for name in new_attr_names:
1741            setattr(new_net, name, getattr(self._origin_network, name))
1742        # merger cells
1743        cells = self._cal_difference_set(self._origin_network.name_cells().keys(), new_net.name_cells().keys())
1744        cells = self._cal_difference_set(cells, self._deleted_node)
1745        for c in cells:
1746            new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c])
1747        # merge primitives
1748        # pylint: disable=protected-access
1749        primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
1750        for p in primitives:
1751            new_net._primitives[p] = self._origin_network._primitives[p] # pylint: disable=protected-access
1752
    def _process_duplicate_name_modules(self, module_ast: ast.Module):
        """
        Adjust names of imported modules with same name and different import path.

        `module_ast.body` is walked in order:
        - for each `ast.ImportFrom`, every imported name (its asname if any; `*` is ignored) is recorded together
          with the module path it comes from. When the name was recorded before, the alias is renamed to
          `<name>_<idx>`. Here idx is the position of the prefix-related recorded path it matches (no rename for
          idx 0), or of the newly appended path when none matches. The rename is remembered in `name_need_suffix`;
        - for each `ast.ClassDef` / `ast.FunctionDef`, the remembered renames are applied to it;
        - at each '#' divider expression, the remembered renames are cleared (recorded paths are kept).
        Every rename is recorded in an AstReplacer appended to `self._tmp_replacers`, so it can be undone later
        with `undo_all()`.

        Args:
            module_ast (ast.Module): Generated module, modified in place.
        """
        # {name1: [path1, path2, ...], ...}
        name_path_dict: Dict[str, List[str]] = {}
        # names of modules need to be suffixed: {name1: suffixed_name1, ...}
        name_need_suffix: Dict[str, str] = {}
        # used to record replace actions in ast.ImportFrom
        import_replacer = AstReplacer(None)
        self._tmp_replacers.append(import_replacer)

        def suffix_alias(alias: ast.alias, suffix: int):
            """suffix the name of alias in ast.ImportFrom and return the new name"""
            new_name = f"{alias.asname}_{suffix}" if alias.asname else f"{alias.name}_{suffix}"
            # Record (object, attribute, old value, new value), presumably so undo_all() can restore asname.
            import_replacer._trace.append((alias, 'asname', alias.asname, new_name)) # pylint: disable=protected-access
            alias.asname = new_name
            return new_name

        def is_divider(ast_node):
            """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
            return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'

        def record_imports(ast_node: ast.ImportFrom):
            """record name and path of imported modules to find the duplicate name modules."""
            for alias in ast_node.names[:]:
                name = alias.asname if alias.asname else alias.name
                if name == '*':
                    continue
                # current name is firstly imported, just record it
                if name not in name_path_dict:
                    name_path_dict[name] = [ast_node.module]
                    continue
                # current name is imported before, check whether it is a duplicated name
                # NOTE(review): ast_node.module is None for `from . import x`; startswith(None) would raise
                # TypeError — presumably relative imports are resolved before this point, confirm. The prefix test
                # is on raw strings, so e.g. 'a.bc' is treated as related to 'a.b'.
                for idx, path in enumerate(name_path_dict[name]):
                    if path.startswith(ast_node.module):
                        # e.g. origin code is 'from a.b.c import A' and new code is 'from a.b import A'
                        # then we update name_path_dict[name][idx] from 'a.b.c' to 'a.b' and update name to A_{idx}
                        name_path_dict[name][idx] = ast_node.module
                        if idx > 0:
                            name_need_suffix[name] = suffix_alias(alias, idx)
                        break
                    elif ast_node.module.startswith(path):
                        # e.g. origin code is 'from a.b import A' and new code is 'from a.b.c import A'
                        # then we just need to update name to A_{idx}
                        if idx > 0:
                            name_need_suffix[name] = suffix_alias(alias, idx)
                        break
                else:
                    # current name is imported from a new path, save the path and update the name
                    name_path_dict[name].append(ast_node.module)
                    name_need_suffix[name] = suffix_alias(alias, len(name_path_dict[name]) - 1)

        def suffix_names_in_ast(ast_node: Union[ast.ClassDef, ast.FunctionDef]):
            """suffix names in ast.ClassDef or ast.FunctionDef"""
            if not name_need_suffix:
                return
            name_replacer = AstReplacer(ast_node)
            # keep the replacer so the renames can be undone after code generation
            self._tmp_replacers.append(name_replacer)
            for name, new_name in name_need_suffix.items():
                name_replacer.replace_all(name, new_name)

        # Imports, definitions and dividers are handled in body order; renames only apply to the definitions
        # that follow the imports which caused them, up to the next divider.
        for ast_node in module_ast.body:
            if isinstance(ast_node, ast.ImportFrom):
                record_imports(ast_node)
            if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
                suffix_names_in_ast(ast_node)
            if is_divider(ast_node):
                name_need_suffix.clear()
1820