• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Node class define of Rewrite. See detail in Node class docstring."""
16from typing import Optional, Union, List, Dict
17import ast
18import inspect
19from types import FunctionType
20import sys
21
22from mindspore.nn import Cell
23from mindspore.ops import Primitive
24from mindspore import log as logger
25from ..api.scoped_value import ScopedValue, ValueType
26from ..api.node_type import NodeType
27from ..common.namespace import is_subtree
28from ..common.error_log import error_str
29from ..ast_helpers import AstModifier, AstReplacer, AstConverter
30from ... import _checkparam as Validator
31
32
33if sys.version_info >= (3, 9):
34    import ast as astunparse # pylint: disable=reimported, ungrouped-imports
35else:
36    import astunparse
37
38
class LocalPrim(Primitive):
    """
    Primitive subclass used to mark a local primitive instance.

    Args:
        prim_obj (type): The wrapped object; kept unchanged as `self.prim_obj`.
    """

    def __init__(self, prim_obj: type):
        # Every LocalPrim passes the same fixed name to Primitive.__init__.
        super(LocalPrim, self).__init__("rewrite_local_prim")
        self.prim_obj = prim_obj
44
45
46class Node:
47    """
48    Node is a data structure represents a source code line in network. For the most part, Node represents an operator
49    invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of
50    Node has different meaning in different type of node:
51
52    - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
53      `targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
54      `kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
55      method. `func` is corresponding to func of call expression which means symbol of the cell-op.
56    - CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore.
57      `targets`, `args`, `kwargs` and `func_name` are as previous.
58    - CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`.
59      `targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name`
60      represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
61      method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
62      mapped to CallMethod node whose `func_name` is "PassThrough".
63    - Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not
64      supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care.
65    - Input: an input node represents an input of current network which also a parameter of forward method of Cell.
66      `targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter
67      of forward function. `kwargs` and `func_name` are don't-care.
68    - Output: an output node represents the output of current network which is corresponding to return statement of
69      forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and
70      `kwargs` are don't-care.
71    - Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so
72      `targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a
73      SymbolTree instance.
74    """
75
    def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: Optional[List[ScopedValue]],
                 func_name: Optional[ScopedValue], args: Optional[List[ScopedValue]],
                 kwargs: Optional[Dict[str, ScopedValue]], name: str, instance):
        """
        Constructor of Node. Rewrite recommends instantiating Node through its class methods such as
        `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and
        `create_output_node`, rather than invoking this constructor directly.

        Args:
            node_type (NodeType): A NodeType as type of Node.
            ast_node (ast.AST, optional): An instance of ast.AST represents corresponding node in ast. `ast_node` should
                not be None except when node type is Unknown.
            targets (list[ScopedValue], optional): Targets of the node. None is stored as an empty list. See detail in
                docstring of Node class.
            func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
            args (list[ScopedValue], optional): Positional arguments of the node. None is counted as zero arguments.
            kwargs (Dict[str, ScopedValue], optional): Keyword arguments of the node, mapping names to ScopedValue.
                None is counted as zero keyword arguments.
            name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
                Name of node also used as field name in network class.
            instance: Object in network corresponding to this node. For CallModule, CallCell and CallPrimitive nodes,
                its public attributes are collected into `self._attribute`.
        """
        self._node_type: NodeType = node_type
        self._ast_node: Optional[ast.AST] = ast_node
        # Public (non-underscore) attributes of the wrapped op plus its class under key "cls"
        # (see _get_cell_or_prim_op_attribute); stays empty for other node types.
        self._attribute: Dict[str, object] = {}
        if node_type in (NodeType.CallModule, NodeType.CallCell, NodeType.CallPrimitive):
            self._attribute = Node._get_cell_or_prim_op_attribute(instance)
        self._instance = instance
        self._name = name
        self._func_name: Optional[ScopedValue] = func_name
        self._targets: List[ScopedValue] = targets if targets is not None else []
        # Number of positional / keyword arguments as passed in (None counts as 0).
        self._args_num = len(args) if args is not None else 0
        self._kwargs_num = len(kwargs) if kwargs is not None else 0
        self._normalized_args_keys = []  # for saving args' order
        # NOTE(review): _get_normalized_args is defined outside this view; it presumably fills
        # self._normalized_args_keys, which is why that list is initialised first. Keep this order.
        self._normalized_args = self._get_normalized_args(args, kwargs)
        # position in graph nodes list
        # it will affect code-order of python code
        self._prev: Optional[Node] = None
        self._next: Optional[Node] = None
        # A handler of SymbolTree current node belonging to
        self._belong_tree = None
        # A handler of NodeManager current node belonging to
        self._node_manager = None
        # A dict that records which target of which Node current Node's argument come from,
        # i.e. {arg_index: (provider_node, provider_target_index)}
        self._arg_providers: Dict[int, tuple] = {}
        # A dict that records which argument of which Node uses current Node's target,
        # i.e. {target_index: [(user_node, user_arg_index), ...]}
        self._target_users: Dict[int, list] = {}
        # Indicate this node represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs)
        self._type_cls = None
        # Indicate this node represent the initialize of a class type, e.g. abs_inst = P.Abs()
        self._init_cls = None
125
126    @classmethod
127    def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
128                           func_name: Union[ScopedValue, str], args: [ScopedValue] = None,
129                           kwargs: {str: ScopedValue}=None, name: str = ""):
130        """
131        Class method of Node. Instantiate an instance of node whose type is CallCell. A CallCell node represents an
132        invoking to cell-op.
133
134        Args:
135            ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. `ast_node`
136                should not be None currently.
137            targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
138            func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
139            args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
140            kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
141            name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
142                Name of node also used as field name in network class.
143        """
144        if args is None:
145            args = []
146        if kwargs is None:
147            kwargs = {}
148        if isinstance(func_name, str):
149            func_name = ScopedValue.create_naming_value(func_name)
150        new_targets = Node._handle_targets(targets)
151        if ast_node is None:
152            raise RuntimeError("Input ast_node is None")
153        return cls(NodeType.CallMethod, ast_node, new_targets, func_name, args, kwargs, name, None)
154
155    @classmethod
156    def create_python_node(cls, ast_node: ast.AST, name: str = "", instance=None):
157        """
158        Class method of Node. Instantiate an instance of node whose type is Python. A Python node represents some python
159        statement is not supported by Rewrite or ignored by Rewrite.
160
161        Args:
162            ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
163            name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
164                Name of node also used as field name in network class.
165            instance: An object corresponding to this node in network.
166        """
167        return cls(NodeType.Python, ast_node, None, None, [], {}, name, instance)
168
169    @classmethod
170    def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None,
171                          name: str = ""):
172        """
173        Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of
174        SymbolTree which is corresponding to parameters of forward function.
175
176        Args:
177            ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
178            arg_name (str): A string represents name of parameter.
179            default ([ScopedValue, optional]): An instance of ScopedValue represents default value of parameter.
180            name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
181                Name of node also used as field name in network class.
182        """
183        target = ScopedValue.create_naming_value(arg_name)
184        if default is None:
185            args = []
186        else:
187            args = [default]
188        if ast_node is None:
189            ast_node = ast.arg(arg_name, annotation="")
190        return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None)
191
192    @classmethod
193    def create_output_node(cls, ast_node: ast.AST, return_value: [ScopedValue], name: str = "return"):
194        """
195        Class method of Node. Instantiate an instance of node whose type is Output. An Output node represents output of
196        SymbolTree which is corresponding to return statement of forward function.
197
198        Args:
199            ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
200            return_values (list[str]): A list of string represents name of return values.
201            name (ScopedValue): An instance of ScopedValue represents name of node.
202        """
203        return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), return_value, {},
204                   name, None)
205
206    @classmethod
207    def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
208                            op_type: ScopedValue, args: [ScopedValue], name: str = ""):
209        """
210        Class method of Node. Instantiate an instance of node whose type is `MathOps` .
211        A mathops node is used to represent a node with mathematical operations, such as
212        `y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.
213
214        Args:
215            ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of
216                node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and
217                ast.Compare.
218            targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`.
219                See detail in docstring of Node class.
220            op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
221            args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
222                sequentially in the list.
223            name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
224                Name of node also used as field name in network class. The format of mathops node name
225                is 'AstNodeName_AstOpName_n'.
226        """
227        return cls(NodeType.MathOps, ast_node, targets, op_type, args, None, name, None)
228
229    @staticmethod
230    def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
231                              kwargs: {str: ScopedValue}=None):
232        """
233        Create a node that corresponds to a function call.
234
235       Args:
236            function (FunctionType): The function to be called.
237            targets (list[str]): indicates output names. Used as targets of an assign statement in source code.
238            args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
239                source code. Default: ``None`` , which indicates the `function` has no args inputs.
240            kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
241                Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
242                code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
243
244        Returns:
245            An instance of `Node`.
246        """
247        if args is None:
248            args = []
249        if kwargs is None:
250            kwargs = {}
251        targets = Node._handle_targets(targets)
252        func_name = function.__name__
253        func_scope_name = ScopedValue.create_naming_value(func_name)
254        node = Node.inner_create_call_function(func_name, None, func_scope_name, function, targets, args, kwargs)
255        return node
256
257    @classmethod
258    def inner_create_call_function(cls, node_name: str, ast_node: ast.Assign, func_name: ScopedValue, func_obj: object,
259                                   targets: List[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue]):
260        '''
261        Instantiate an instance of node whose type is `CallFunction`.
262
263        Args:
264            node_name (str): Name of node.
265            func_name (ScopedValue): Name of function.
266            ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
267            func_obj (Object): An object of function. See detail in docstring of Node class.
268            targets (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
269            args (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
270            kwargs (Dict[str, ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of `Node`
271                class.
272        '''
273        from . import CallFunction
274        # create CallFunction node
275        return CallFunction(targets, func_name, args, kwargs, node_name, ast_node, None, None, func_obj, False)
276
277    @staticmethod
278    def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
279                       args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "",
280                       is_sub_net: bool = False):
281        """
282        Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
283        If op is custom defined, it is treated by TreeNode.
284        A `CallCell` node represents an invoking to cell-op.
285        A `CallPrimitive` node represents an invoking to primitive-op.
286
287        Args:
288            op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
289            ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
290            targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
291            args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
292            kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
293                class.
294            node_name (str): A string represents name of node. Name of node will be unique when inserted into
295                `SymbolTree`. Name of node also used as field name in network class.
296            is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse the
297                `cell` to a TreeNode, else a CallCell Node. Default is a False.
298        """
299        Validator.check_value_type("op", op, [Cell, Primitive], "Node")
300        if ast_node is not None:
301            Validator.check_value_type("ast_node", ast_node, [ast.AST], "Node")
302        Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
303        if args is not None:
304            Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
305        if kwargs is not None:
306            Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
307        if args is None:
308            args = []
309        if kwargs is None:
310            kwargs = {}
311        Validator.check_value_type("node_name", node_name, [str], "Node")
312        new_targets = Node._handle_targets(targets)
313        if isinstance(node_name, str):
314            func_name = ScopedValue.create_naming_value(node_name)
315        else:
316            func_name = node_name
317        if is_sub_net and is_subtree(op):
318            from ..symbol_tree import SymbolTreeBuilder
319            stb = SymbolTreeBuilder(op)
320            stree = stb.build()
321            replacer = AstReplacer(stree.get_class_ast())
322            replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
323            return TreeNode.create_tree_node(stree, ast_node, new_targets, func_name, args, kwargs, node_name, op)
324
325        return Node.create_call_buildin_op(op, ast_node, new_targets, func_name, args, kwargs, node_name)
326
327    @classmethod
328    def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
329                               func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
330                               node_name: str = ""):
331        """
332        Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
333        A `CallCell` node represents an invoking to cell-op.
334        A `CallPrimitive` node represents an invoking to primitive-op.
335
336        Args:
337            op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
338            ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
339            targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
340            func_name ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
341            args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
342            kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
343                class.
344            node_name (str): A string represents name of node. Name of node will be unique when inserted into
345                `SymbolTree`. Name of node also used as field name in network class.
346        """
347
348        if not isinstance(op, (Cell, Primitive)):
349            raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
350        if isinstance(op, Cell):
351            node_type = NodeType.CallCell
352        else:
353            node_type = NodeType.CallPrimitive
354        return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op)
355
356    @staticmethod
357    def _get_construct_arg_names(parameters):
358        """
359        Static method of `Node`. Get parameters' names of the construct function.
360
361        Args:
362            parameters (MappingProxyType): An ordered mapping of parameters' names to the corresponding Parameter
363                objects.
364
365        Raises:
366            RuntimeError: Invalid parameter kind.
367
368        Returns:
369            - arg_names, Parameters' names, contain parameters of types in [POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD].
370            - var_positional_name, Name of VAR_POSITIONAL parameters.
371            - var_keyword_name, Name of VAR_KEYWORD parameters.
372        """
373        position_only_names: [str] = []
374        positional_or_keyword_names: [str] = []
375        var_positional_name = None
376        keyword_only_names: [str] = []
377        var_keyword_name = None
378        for name, para in parameters.items():
379            if para.kind == inspect.Parameter.POSITIONAL_ONLY:  # parameters which appear before a '/'
380                position_only_names.append(name)
381            elif para.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:  # parameters which appear before '*' or '*args'
382                positional_or_keyword_names.append(name)
383            elif para.kind == inspect.Parameter.VAR_POSITIONAL:  # corresponds to a '*args'
384                var_positional_name = name
385            elif para.kind == inspect.Parameter.KEYWORD_ONLY:  # parameters which appear after '*' and before '**'
386                keyword_only_names.append(name)
387            elif para.kind == inspect.Parameter.VAR_KEYWORD:  # corresponds to a '**kwargs'
388                var_keyword_name = name
389            else:
390                raise RuntimeError("invalid parameter kind:", para.kind)
391        if "self" in position_only_names:
392            position_only_names.remove("self")
393        if "self" in positional_or_keyword_names:
394            positional_or_keyword_names.remove("self")
395        names = (position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names,
396                 var_keyword_name)
397        return names
398
399    @staticmethod
400    def _map_args_names(names: tuple, args: [ScopedValue], kwargs: {str: ScopedValue},
401                        normalized_args_keys: [str], normalized_args: {str: ScopedValue}):
402        """
403        Fill in normalized_args according to the order of parameters of construct func.
404
405        Args:
406            names (tuple): Parameters' name got from construct func.
407            args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
408            kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
409            normalized_args (dict{str: ScopedValue}): The normalized args to be filled.
410
411        Raises:
412            RuntimeError: Input args are invalid.
413            RuntimeError: Arg name already exist in kwargs.
414            RuntimeError: Input kwargs invalid.
415        """
416        position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names, var_keyword_name = \
417            names
418        for arg_index, arg in enumerate(args):
419            if arg_index < len(position_only_names):
420                arg_key = position_only_names[arg_index]
421            elif arg_index < len(position_only_names) + len(positional_or_keyword_names):
422                arg_key = positional_or_keyword_names[arg_index - len(position_only_names)]
423            elif var_positional_name:
424                arg_key = "{}_{}".format(var_positional_name, arg_index)
425            else:
426                raise RuntimeError("Input args are invalid.")
427
428            if arg_key in kwargs.keys():
429                raise RuntimeError("Arg name already exist in kwargs.")
430            normalized_args[arg_key] = arg
431            normalized_args_keys.append(arg_key)
432
433        # add kwargs according to parameters' order
434        parameters_order: [str] = []
435        parameters_order.extend(position_only_names)
436        parameters_order.extend(positional_or_keyword_names)
437        parameters_order.append(var_keyword_name)
438        parameters_order.extend(keyword_only_names)
439        parameters_order.append(var_keyword_name)
440
441        sorted_kwargs = []
442        var_keyword_count = len(parameters_order)
443        for arg_key, value in kwargs.items():
444            if arg_key not in parameters_order and not var_keyword_name:
445                raise RuntimeError("Input kwargs invalid.")
446            if arg_key in parameters_order:
447                sorted_kwargs.append([arg_key, value, parameters_order.index(arg_key)])
448            else:
449                sorted_kwargs.append([arg_key, value, var_keyword_count])
450                var_keyword_count += 1
451
452        sorted_kwargs.sort(key=lambda x: x[2])
453        for sorted_kwarg in sorted_kwargs:
454            normalized_args[sorted_kwarg[0]] = sorted_kwarg[1]
455            normalized_args_keys.append(sorted_kwarg[0])
456
457    @staticmethod
458    def _handle_custom_obj_in_args(args: [ScopedValue]) -> [ScopedValue]:
459        """
460        Convert CustomObjValue type argument to NamingValue type argument.
461
462        Args:
463            args (list[ScopedValue]): A list of instance of ScopedValue to be converted.
464
465        Returns:
466            A list of instance of ScopedValue which have been converted.
467        """
468        result = []
469        for arg in args:
470            if not isinstance(arg, ScopedValue):
471                raise TypeError("arg should be ScopedValue, got: ", type(arg))
472            if arg.type == ValueType.CustomObjValue:
473                logger.info("custom-object exist in args, should be replace before compile")
474                result.append(ScopedValue.create_naming_value("custom-object", "self"))
475            else:
476                result.append(arg)
477        return result
478
479    @staticmethod
480    def _handle_custom_obj_in_kwargs(kwargs: {str: ScopedValue}) -> {str: ScopedValue}:
481        """
482        Convert CustomObjValue type argument to NamingValue type argument.
483
484        Args:
485            kwargs (dict{str: ScopedValue}): A str to instance of ScopedValue dict whose value to be converted.
486
487        Returns:
488            A str to instance of ScopedValue dict whose value has be converted.
489        """
490        result: {str, ScopedValue} = {}
491        for arg, value in kwargs.items():
492            if not isinstance(value, ScopedValue):
493                raise TypeError("value should be ScopedValue, got: ", type(value))
494            if value.type == ValueType.CustomObjValue:
495                result[arg] = ScopedValue.create_naming_value("custom-object", "self")
496            else:
497                result[arg] = value
498        return result
499
500    @staticmethod
501    def _handle_targets(targets: [Union[ScopedValue, str]]) -> [ScopedValue]:
502        """
503        Normalize targets to be a list of ScopedValue. If target is a str, it will be converted to NamingValue type
504        ScopedValue.
505
506        Args:
507            targets (Union[ScopedValue, str]]): A list whose element could be a ScopedValue or a str to be normalized.
508
509        Returns:
510            A list of instance of ScopedValue which have been converted.
511        """
512        if not isinstance(targets, list):
513            raise TypeError("targets should be list, got: ", type(targets))
514        results = []
515        for target in targets:
516            if isinstance(target, str):
517                scope = ""
518                name = target
519                if target.count('.') > 0:
520                    scope, name = target.rsplit('.', 1)
521                results.append(ScopedValue.create_naming_value(name, scope))
522            elif isinstance(target, ScopedValue):
523                results.append(target)
524            else:
525                raise RuntimeError("Invalid symbol type: ", target)
526        return results
527
528    @staticmethod
529    def _get_cell_or_prim_op_attribute(obj) -> dict:
530        """
531        Find attributes of cell-op or primitive-op.
532
533        Args:
534            obj: A cell-op or a primitive-op.
535
536        Returns:
537            A dict represents attributes of input 'obj'.
538        """
539        attributes = {}
540        if obj is None:
541            return attributes
542        for k, v in obj.__dict__.items():
543            if k.startswith("_"):
544                continue
545            attributes[k] = v
546        attributes["cls"] = obj.__class__
547        return attributes
548
549    def get_type_cls(self) -> object:
550        """Get the class type object this node represented, e.g. abs_ops = _get_cache_prim(P.Abs)"""
551        return self._type_cls
552
553    def set_type_cls(self, x):
554        """Set the class type object this node represented, e.g. abs_ops = _get_cache_prim(P.Abs)"""
555        self._type_cls = x
556
557    def get_init_cls(self) -> object:
558        """Get the class type object initialized by this node, e.g. abs_inst = P.Abs()"""
559        return self._init_cls
560
561    def set_init_cls(self, x):
562        """Set the class type object initialized by this node"""
563        self._init_cls = x
564
565    def get_prev(self) -> 'Node':
566        """
567        Get previous node of current node in source code order.
568
569        Returns:
570            An instance of Node as previous node.
571        """
572        return self._prev
573
574    def get_next(self) -> 'Node':
575        """
576        Get next node of current node in source code order.
577
578        Returns:
579            An instance of Node as next node.
580        """
581        return self._next
582
583    def set_prev(self, node: 'Node'):
584        """
585        Set previous node of current node.
586
587        Args:
588            node (Node): Node to be set as previous node of current node.
589        """
590        self._prev = node
591
592    def set_next(self, node: 'Node'):
593        """
594        Set next node of current node.
595
596        Args:
597            node (Node): Node to be set as next node of current node.
598        """
599        self._next = node
600
601    def get_ast(self) -> Optional[ast.AST]:
602        """
603        Getter of _ast_node.
604
605        Returns:
606            An instance of ast.AST if self._ast_node if not None else None.
607        """
608        return self._ast_node
609
610    def set_ast(self, ast_node: ast.AST):
611        """
612        Setter of _ast_node.
613
614        Args:
615            ast_node (ast.AST): An instance of ast.AST as new value for _ast_node.
616        """
617        if not isinstance(ast_node, ast.AST):
618            raise TypeError("ast_node should be ast.AST, got: ", type(ast_node))
619        self._ast_node = ast_node
620
621    def get_belong_symbol_tree(self):
622        """Get the symbol tree to which node belongs."""
623        return self._belong_tree
624
625    def set_belong_symbol_tree(self, symbol_tree):
626        """Set the symbol tree to which node belongs."""
627        self._belong_tree = symbol_tree
628
629    def get_node_manager(self):
630        """Get the NodeManager current node belongs to."""
631        return self._node_manager
632
633    def set_node_manager(self, node_manager):
634        """Set NodeManager current node belongs."""
635        self._node_manager = node_manager
636
637    def isolate(self):
638        """Link prev node to next node and isolate node from source code order list."""
639        origin_prev: Optional[Node] = self.get_prev()
640        origin_next: Optional[Node] = self.get_next()
641        if origin_prev is not None:
642            origin_prev.set_next(origin_next)
643        if origin_next is not None:
644            origin_next.set_prev(origin_prev)
645        self.set_prev(None)
646        self.set_next(None)
647
648    def insert_before(self, node: 'Node'):
649        """
650        Insert a node before current node in source code list. Note that topological order is not determined here.
651
652        Args:
653            node (Node): An instance of node to be inserted in.
654        """
655        node.isolate()
656        origin_prev: Optional[Node] = self.get_prev()
657        if origin_prev is not None:
658            origin_prev.set_next(node)
659        node.set_prev(origin_prev)
660        node.set_next(self)
661        self.set_prev(node)
662
663    def insert_after(self, node: 'Node'):
664        """
665        Insert a node after current node in source code list. Note that topological order is not determined here.
666
667        Args:
668            node (Node): An instance of node to be inserted in.
669        """
670        node.isolate()
671        origin_next: Optional[Node] = self.get_next()
672        self.set_next(node)
673        node.set_prev(self)
674        node.set_next(origin_next)
675        if origin_next is not None:
676            origin_next.set_prev(node)
677
678    def get_inputs(self) -> ['Node']:
679        """
680        Get input nodes of current node in topological order.
681
682        Returns:
683            A list of instances of Node as input nodes.
684        """
685        inputs = []
686        for arg_provider in self.get_arg_providers().values():
687            if not arg_provider:
688                continue
689            inputs.append(arg_provider[0])
690        return inputs
691
692    def get_users(self) -> ['Node']:
693        """
694        Get user nodes of current node in topological order.
695
696        Returns:
697            A list of instances of Node as user nodes.
698        """
699        users = []
700        for target_users in self.get_target_users().values():
701            if not target_users:
702                continue
703            for (user, _) in target_users:
704                if user not in users:
705                    users.append(user)
706        return users
707
708    def get_targets(self) -> [ScopedValue]:
709        """
710        Getter of _targets.
711
712        - When node_type of current node is CallCell or CallPrimitive or CallMethod or Tree, `targets` are strings
713          represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of
714          ast.Assign.
715        - When node_type of current node is Input, `targets` should have only one element which is a string represents
716          name of parameter of function.
717        - When node_type of current node is Python or Output, `targets` are don't-care.
718
719        Returns:
720            A list of instances of ScopedValue as targets of node.
721        """
722        return self._targets
723
724    def set_targets(self, targets: [ScopedValue]):
725        """
726        Setter of _targets.
727
728        Note:
729            This interface can only be called before node been inserted into symbol-tree because target will be unique
730            while insert into symbol-tree, in other word, set_targets is not a user-interface.
731
732            When `_targets` is updated, corresponding ast node would be updated also.
733
734            When node_type of current node is CallCell or CallPrimitive or CallMethod or Tree, `targets` are strings
735            represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets
736            of ast.Assign.
737
738            When node_type of current node is Input, `targets` should have only one element which is a string represents
739            name of parameter of function.
740
741            When node_type of current node is Python or Output, `targets` are don't-care.
742
743        Args:
744            targets ([ScopedValue]): A list of instances of ScopedValue as new targets.
745        """
746        self._targets = targets
747        if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
748                               NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer,
749                               NodeType.MathOps):
750            self._sync_assign_targets_to_ast()
751
752    def get_func_name(self) -> ScopedValue:
753        """
754        Getter of `_func_name`. See detail in docstring of Node class for meaning of func.
755
756        Returns:
757            An instance of ScopedValue.
758        """
759        return self._func_name
760
761    def set_func_name(self, func_name: ScopedValue):
762        """
763        Setter of `_func_name`. See detail in docstring of Node class for meaning of func.
764
765        Note:
766            When `_func_name` is updated, corresponding ast node would be updated also.
767
768        Args:
769            func (ScopedValue): An instance of ScopedValue as new func.
770        """
771        self._func_name = func_name
772        if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive):
773            self._sync_assign_func_name_to_ast()
774
775    def get_name(self) -> str:
776        """
777        Getter of `_name`.
778
779        Returns:
780            A str represents name of node.
781        """
782        return self._name
783
784    def set_name(self, name: str):
785        """
786        Setter of `_name`.
787
788        Args:
789            name (str): A str as new name of node.
790        """
791        self._name = name
792
793    def get_node_type(self) -> NodeType:
794        """
795        Get the node_type of current node.
796
797        Returns:
798            A NodeType as node_type of node.
799        """
800        return self._node_type
801
802    def get_instance_type(self) -> type:
803        """
804        Get the instance_type of current node.
805
806        - When node_type of current node is CallCell, instance_type is type of cell-op.
807        - When node_type of current node is CallPrimitive, instance_type is type of primitive-op.
808        - When node_type of current node is Tree, instance_type is type of network-cell.
809        - When node_type of current node is Python, Input, Output or CallMethod, instance_type should be NoneType
810
811        Returns:
812            A type.
813        """
814        if isinstance(self._instance, LocalPrim):
815            return self._instance.prim_obj
816        if inspect.isfunction(self._instance):
817            return self._instance
818        return type(self._instance)
819
820    def get_instance(self):
821        """
822        Get the instance of current node.
823
824        - When node_type of current node is CallCell, instance is an instance of Cell.
825        - When node_type of current node is CallPrimitive, instance is an instance of primitive.
826        - When node_type of current node is Tree, instance is an instance of network-cell.
827        - When node_type of current node is Python, Input, Output or CallMethod, instance should be None
828
829        Returns:
830            A object.
831        """
832        return self._instance
833
834    def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None):
835        """
836        Set argument by another Node.
837        Note that when _normalized_args is updated, corresponding ast node would be updated also.
838
839        Args:
840            arg_idx (int): Indicate which input being modified.
841            node (Node): Node as new input. Can be a node or name of node.
842            out_idx ([int, optional]): Indicate which output of `node` as new argument. Default is None which means use
843                first output of `node_to_link` as new input.
844
845        Raises:
846            ValueError: If `arg_idx` is out of range.
847            ValueError: If `node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
848        """
849        Validator.check_value_type("node", node, [Node], "Node")
850        Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
851        if out_idx is None:
852            if len(node.get_targets()) != 1:
853                raise ValueError("node should has one output when out_idx is not provided")
854            out_idx = 0
855        Validator.check_int_range(out_idx, 0, len(node.get_targets()), Validator.INC_LEFT, "arg_idx")
856        new_arg = node.get_targets()[out_idx]
857        self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
858        self._sync_arg()
859
860    def set_arg(self, arg: Union[ScopedValue, str], index: int) -> (ScopedValue, ScopedValue):
861        """
862        Set argument of `node`.
863        Note that when _normalized_args is updated, corresponding ast node would be updated also.
864
865        Args:
866            index (int): Indicate which input being modified.
867            arg (Union[ScopedValue, str]): New argument to been set.
868
869        Raises:
870            ValueError: If `index` is out of range.
871        """
872        Validator.check_int_range(index, 0, self._args_num, Validator.INC_LEFT, "index")
873        Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
874        if isinstance(arg, str):
875            arg = ScopedValue.create_naming_value(arg)
876        old_arg = self._normalized_args.get(self._normalized_args_keys[index])
877        self._normalized_args[self._normalized_args_keys[index]] = arg
878        self._sync_arg()
879        return arg, old_arg
880
881    def set_args(self, args: [ScopedValue]):
882        """
883        Set arguments of `node`.
884        Note that when _normalized_args is updated, corresponding ast node would be updated also.
885
886        Args:
887            args (list[ScopedValue]): New arguments to been set.
888
889        Raises:
890            TypeError: Element of new argument is not an instance of ScopedValue.
891        """
892        Validator.check_int_range(len(args), 0, self._args_num, Validator.INC_LEFT, "Length of args")
893        Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
894        for arg_index, arg in enumerate(args):
895            if not isinstance(arg, ScopedValue):
896                raise TypeError("arg should be ScopedValue, got: ", type(arg))
897            self._normalized_args[self._normalized_args_keys[arg_index]] = arg
898        self._sync_arg()
899
900    def set_kwargs(self, kwargs: {str: ScopedValue}):
901        """
902        Set keywords arguments of 'node'.
903        Note that when _normalized_args is updated, corresponding ast node would be updated also.
904
905        Args:
906            kwargs (dict{str: ScopedValue}): New arguments to been set.
907
908        Raises:
909            TypeError: Value of new argument is not an instance of ScopedValue.
910            RuntimeError: Length of new arguments is not equal to length of old arguments.
911        """
912        Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Validator.INC_LEFT, "Length of kwargs")
913        Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
914        for key, arg in kwargs.items():
915            if key not in self._normalized_args.keys() or key not in self._normalized_args_keys:
916                raise RuntimeError("Input key is not exist, ", key)
917            if not isinstance(arg, ScopedValue):
918                raise TypeError("arg should be ScopedValue, got: ", type(arg))
919            self._normalized_args[key] = arg
920        self._sync_arg()
921
922    def set_kwarg(self, key: str, arg: ScopedValue):
923        """
924        Set keyword argument of 'node'.
925        Note that when _normalized_args is updated, corresponding ast node would be updated also.
926
927        Args:
928            key (str): A str represents key of new argument.
929            arg (ScopedValue): An instance of ScopedValue represents argument.
930
931        Raises:
932            RuntimeError: If 'key' is not in original kwargs' keys.
933        """
934        if key not in self._normalized_args_keys[self._args_num:] or key not in self._normalized_args.keys():
935            raise RuntimeError("Input key is not exist, ", key)
936        self._normalized_args[key] = arg
937        self._sync_arg()
938
939    def get_args(self):
940        """
941        Get the arguments of current node.
942
943        - When node_type of current node is CallCell, CallPrimitive or Tree, arguments are corresponding to args of
944          ast.Call which represents arguments to invoke cell-op's forward method or primitive-op's `call()` method.
945        - When node_type of current node is Input, arguments represents default-value of argument of function.
946        - When node_type of current node is Output, arguments represents return values.
947        - When node_type of current node is Python, arguments are don't-care.
948
949        Returns:
950            A list of instances of ScopedValue.
951        """
952        args = []
953        for arg_index in range(self._args_num):
954            args.append(self._normalized_args.get(self._normalized_args_keys[arg_index]))
955        return args
956
957    def get_kwargs(self):
958        """
959        Get the keyword arguments of current node.
960
961        - When node_type of current node is CallCell, CallPrimitive or Tree, keyword arguments are corresponding to
962          kwargs of ast.Call which represents arguments to invoke cell-op's forward method or primitive-op's `call()`
963          method.
964        - When node_type of current node is Python, Input or Output, keyword arguments are don't-care.
965
966        Returns:
967            A dict of str to instance of ScopedValue.
968        """
969        kwargs: {str, ScopedValue} = {}
970        for arg_index in range(self._args_num, self._args_num + self._kwargs_num):
971            key = self._normalized_args_keys[arg_index]
972            kwargs[key] = self._normalized_args.get(key)
973        return kwargs
974
975    def get_normalized_args(self) -> {str: ScopedValue}:
976        """
977        Get the normalized keyword arguments of current node.
978        Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter name as
979        key of arguments.
980
981        Returns:
982            A dict of str to instance of ScopedValue.
983        """
984        output = {}
985        for key in self._normalized_args_keys:
986            output[key] = self._normalized_args.get(key)
987        return output
988
989    def set_normalized_args(self, args: {str, ScopedValue}):
990        """
991        Set the normalized keyword arguments of current node.
992        Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter name as
993        key of arguments.
994
995        Args:
996            args ({str, ScopedValue}): A dict of str to instance of ScopedValue represents new normalized_args.
997        """
998        if len(args.values()) != len(self._normalized_args_keys):
999            raise RuntimeError("Length of args.values() should be equal to length of _normalized_args_keys, ",
1000                               len(args.values()), " vs ", len(self._normalized_args_keys))
1001        for key, arg in args.items():
1002            self._normalized_args[key] = arg
1003        self._sync_arg()
1004
1005    def set_attribute(self, key: str, value):
1006        """
1007        Set attribute of current node.
1008
1009        Args:
1010            key (str): Key of new attribute.
1011            value (object): Value of new attribute.
1012        """
1013        self._attribute[key] = value
1014
1015    def set_attributes(self, attributes):
1016        """
1017        Set attributes of current node.
1018
1019        Args:
1020            attributes (dict): A dict represents new attributes.
1021        """
1022        self._attribute = attributes
1023
1024    def get_attributes(self):
1025        """
1026        Get all attributes of current node.
1027
1028        Returns:
1029            A dict of str to instance of object as attributes.
1030        """
1031        return self._attribute
1032
1033    def get_attribute(self, key: str):
1034        """
1035        Get attribute of current node by key.
1036
1037        Args:
1038            key (str): A str represents key of attribute you want to get.
1039
1040        Returns:
1041            A object as attribute.
1042        """
1043        return self._attribute.get(key)
1044
1045    def get_arg_providers(self) -> dict:
1046        """
1047        Getter of _arg_providers.
1048
1049        Return:
1050            dict, key is type of int indicating the index of args, and value is type of tuple, which includes
1051                the node and the index of node's targets who provides the argument.
1052        """
1053        return self._arg_providers
1054
1055    def set_arg_providers(self, index: int, provider: tuple):
1056        """
1057        Setter of _arg_providers.
1058
1059        Args:
1060            index (int): Indicating provider of which argument need to be set.
1061            provider (tuple): A tuple includes the node and the index of node's targets who provides the argument.
1062        """
1063        self._arg_providers[index] = provider
1064
1065    def get_target_users(self, index=-1) -> Union[dict, list]:
1066        """
1067        Getter of _target_users.
1068
1069        Args:
1070            index (int): Indicating users of which target need to be got. Default: -1, means all targets's users will
1071                be returned.
1072
1073        Return:
1074            Union[dict, list]. When index is not -1, a list of users of specified target will be returned.
1075                The type of elements in list is tuple, which includes the user node and the index of node's arguments
1076                who uses the target. When index is -1, a dict will be returned. The key is index of targets, and the
1077                value is list of users of corresponding target.
1078        """
1079        if index == -1:
1080            return self._target_users
1081        if index not in self._target_users.keys():
1082            self._target_users[index] = []
1083        return self._target_users.get(index, None)
1084
1085    def append_target_users(self, index: int, provider: tuple):
1086        """
1087        Setter of _target_users.
1088
1089        Args:
1090            index (int): Indicating users of which target need to be append.
1091            provider (tuple): A tuple includes the node and the index of node's argument who uses the target.
1092
1093        """
1094        if index not in self._target_users.keys():
1095            self._target_users[index] = []
1096        self._target_users.get(index).append(provider)
1097
1098    def update_ast_node(self) -> ast.AST:
1099        """Update node's ast_node by current targets, func_name, args and kwargs."""
1100        ast_assign = AstModifier.create_call_assign(self.get_targets(), self.get_func_name(),
1101                                                    self.get_args(), self.get_kwargs())
1102        self.set_ast(ast_assign)
1103        return ast_assign
1104
1105    def get_source_code(self) -> str:
1106        """Get source code of node from ast of node."""
1107        return astunparse.unparse(self._ast_node).strip()
1108
1109    def append_kwarg(self, kwarg: Dict[str, ScopedValue]):
1110        """
1111        Append a new keyword arg to node.
1112
1113        Args:
1114            kwarg (Dict[str, ScopedValue]): The new keyword arg.
1115
1116        """
1117        if self.get_node_type() not in [NodeType.Tree, NodeType.CallFunction]:
1118            raise TypeError(f"For append_new_kwarg, the type of node can only be one of [Tree, CallFunction], "
1119                            f"but got {self.get_node_type()}")
1120        Validator.check_element_type_of_dict("kwarg", kwarg, [str], [ScopedValue], "append_new_kwarg")
1121        for arg_key, value in kwarg.items():
1122            # add keyword into _normalized_args
1123            self._normalized_args[arg_key] = value
1124            self._normalized_args_keys.append(arg_key)
1125            self._kwargs_num += 1
1126            # add keyword ast into ast.Call
1127            ast_assign: ast.Assign = self._ast_node
1128            ast_call: ast.Call = ast_assign.value
1129            new_keyword = ast.keyword(arg=arg_key, value=AstModifier.get_ast_by_value(value, None))
1130            ast_call.keywords.append(new_keyword)
1131
1132    def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
1133        """
1134        Merge args and kwargs to normalized args.
1135        The keys of args are obtained from the construct function of type(self._instance).
1136
1137        Args:
1138            args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1139            kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
1140
1141        Raises:
1142            RuntimeError: Input args are invalid.
1143            RuntimeError: Arg name already exist in kwargs.
1144
1145        Returns:
1146            The normalized args.
1147        """
1148        if not args:
1149            args = []
1150        if not kwargs:
1151            kwargs = {}
1152        normalized_args: dict = dict()
1153        if (args or kwargs) and self._instance and hasattr(type(self._instance), "construct"):
1154            parameters = inspect.signature(type(self._instance).construct).parameters
1155            names = Node._get_construct_arg_names(parameters)
1156            Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args)
1157        else:
1158            logger.debug("fail to get arg name from op, using arg_xx for args' name")
1159            arg_temp_name, suffix = "arg", 0
1160            for arg in args:
1161                arg_key = "{}_{}".format(arg_temp_name, suffix)
1162                while arg_key in kwargs.keys() or arg_key in normalized_args.keys():
1163                    suffix += 1
1164                    arg_key = "{}_{}".format(arg_temp_name, suffix)
1165                normalized_args[arg_key] = arg
1166                self._normalized_args_keys.append(arg_key)
1167            for arg_key, value in kwargs.items():
1168                normalized_args[arg_key] = value
1169                self._normalized_args_keys.append(arg_key)
1170        return normalized_args
1171
1172    # Synchronize rewrite node args to ast node
1173    def _sync_assign_func_name_to_ast(self):
1174        """Sync func_name of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
1175        if self._ast_node is None:
1176            return
1177        assign_ast = self._ast_node
1178        if not isinstance(assign_ast, ast.Assign):
1179            raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
1180        call_ast = assign_ast.value
1181        if not isinstance(call_ast, ast.Call):
1182            raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
1183        if self._func_name.type == ValueType.UnsupportedValue:
1184            return
1185        func_ast = call_ast.func
1186        if not self._func_name.scope:
1187            if isinstance(func_ast, ast.Name):
1188                func_ast.id = self._func_name.value
1189            else:
1190                call_ast.func = ast.Name(self._func_name.value, ast.Store())
1191        else:
1192            if isinstance(func_ast, ast.Attribute):
1193                if not isinstance(func_ast.value, ast.Name):
1194                    func_ast.value = ast.Name(self._func_name.scope, ast.Load())
1195                else:
1196                    func_ast.value.id = self._func_name.scope
1197                func_ast.attr = self._func_name.value
1198            else:
1199                call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()),
1200                                              self._func_name.value, ast.Store())
1201        ast.fix_missing_locations(assign_ast)
1202
    def _sync_assign_targets_to_ast(self):
        """
        Sync targets of ast.Assign from self._targets.

        Called from set_targets for call-like node types (CallCell, CallMethod, CallPrimitive, Tree, CallFunction,
        CellContainer, MathOps). Each element that AstConverter.get_ast_target_elems returns for the first assignment
        target is paired with the ScopedValue at the same position in self._targets.

        Raises:
            TypeError: If the ast node of this node is not an ast.Assign.
            ValueError: If the number of self._targets differs from the number of target elements in the ast.
        """
        if self._ast_node is None:
            return
        assign_ast = self._ast_node
        if not isinstance(assign_ast, ast.Assign):
            raise TypeError(error_str(f"assign_ast should be ast.Assign, but got: {type(assign_ast)}",
                                      father_node=assign_ast))
        # update targets
        target_ast_elems = AstConverter.get_ast_target_elems(assign_ast.targets[0])
        if len(self._targets) != len(target_ast_elems):
            raise ValueError(error_str(f"The number of targets should be {len(target_ast_elems)}, "
                                       f"but got {len(self._targets)}", father_node=assign_ast))
        for i, target_ast in enumerate(target_ast_elems):
            # NOTE(review): this writes into the list returned by get_ast_target_elems. The change reaches the ast
            # only if that list is the ast's own element list, or if get_ast_by_value updates `target_ast` in
            # place. Confirm against AstConverter/AstModifier.
            target_ast_elems[i] = AstModifier.get_ast_by_value(self._targets[i], target_ast)
1218
    def _sync_call_args_to_ast(self):
        """
        Sync args of ast.Call from self._normalized_args.

        The first `self._args_num` normalized args are written to the positional args of the ast.Call. The next
        `self._kwargs_num` ones are written to its keywords, matched by keyword name rather than by position.

        Raises:
            TypeError: If the ast node is not an ast.Assign, or its value is not an ast.Call. An ast.Attribute value
                on a CellContainer/CallCell node is the exception: it is skipped silently.
            ValueError: If the number of ast args plus keywords differs from the number of normalized args.
        """
        if self._ast_node is None:
            return
        assign_ast = self._ast_node
        if not isinstance(assign_ast, ast.Assign):
            raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node should be "
                            f"ast.Assign, but got: {type(assign_ast)}")
        assign_value = assign_ast.value
        if not isinstance(assign_value, ast.Call):
            if isinstance(assign_value, ast.Attribute) and self._node_type in (NodeType.CellContainer,
                                                                               NodeType.CallCell):
                # CellContainers in control flow may be flattened to ast.Attribute: blocks_var = self.blocks
                # In this case, no args exist in node, so we don't need to sync.
                # CellContainers may be of type CallCell when they share one implementation
                return
            raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node.value should "
                            f"be ast.Call, but got: {type(assign_value)}")
        keywords_ast = assign_value.keywords
        args_ast = assign_value.args
        if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)):
            raise ValueError("ast keywords plus args len is not equal to self._normalized_args value")
        # Positional args: the i-th normalized key maps to the i-th ast arg.
        for arg_index in range(self._args_num):
            arg_ast = args_ast[arg_index]
            args_ast[arg_index] = \
                AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast)

        # the order of kwargs may not be the same as that in keywords_ast
        keyword_map_index = {}
        for index, keyword_ast in enumerate(keywords_ast):
            keyword_map_index[keyword_ast.arg] = index
        for keyword_index in range(self._kwargs_num):
            key = self._normalized_args_keys[keyword_index + self._args_num]
            # NOTE(review): if `key` has no matching keyword in the ast, keyword_map_index.get(key) is None. The
            # indexing below then raises TypeError instead of a descriptive error.
            keywords_ast[keyword_map_index.get(key)].value = \
                AstModifier.get_ast_by_value(self._normalized_args.get(key),
                                             keywords_ast[keyword_map_index.get(key)].value)
1255
    def _sync_call_method_args_to_ast(self):
        """
        Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
        For node with type of CallMethod, the value of ast.Assign is one of:
        | func_name      | data_type   | value of ast.Assign     |
        |:---------------|:------------|:------------------------|
        | 'pass_through' | constants   | ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str |
        | 'pass_through' | variables   | ast.Name, ast.Attribute |
        | 'tuple'        | tuple       | ast.Tuple               |
        | 'list'         | list        | ast.List                |
        | 'dict'         | dict        | ast.Dict                |

        For 'pass_through', the ast generated from the first normalized arg becomes the whole value. For
        'tuple'/'list'/'dict', the i-th normalized arg is written to the i-th element; for ast.Dict that is the
        i-th value, and keys are left untouched.

        Raises:
            TypeError: If the ast node is not an ast.Assign, or func_name is not one of the names above.
            ValueError: If the number of normalized args differs from the number of container elements.
        """
        if self._ast_node is None:
            return
        assign_ast = self._ast_node
        if not isinstance(assign_ast, ast.Assign):
            raise TypeError(f"For node '{self.get_name()}', assign_ast should be ast.Assign, got: ", type(assign_ast))
        assign_value = assign_ast.value
        if self._func_name.value == "pass_through":
            # update constants/variables
            assign_ast.value = \
                AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), assign_value)
        elif self._func_name.value in ("tuple", "list", "dict"):
            # update tuple/list/dict
            # ast.Dict keeps its values in `.values`; ast.Tuple/ast.List keep their elements in `.elts`.
            ast_elts = assign_value.values if isinstance(assign_value, ast.Dict) else assign_value.elts
            if len(self._normalized_args_keys) != len(ast_elts):
                raise ValueError(f"For node '{self.get_name()}', size of self._normalized_args_keys"
                                 f"({len(self._normalized_args_keys)}) should be equal to size of elements of "
                                 f"ast_elts({len(ast_elts)})")
            for index, elt in enumerate(ast_elts):
                scoped_value: ScopedValue = self._normalized_args.get(self._normalized_args_keys[index])
                ast_elts[index] = AstModifier.get_ast_by_value(scoped_value, elt)
        else:
            # NOTE(review): "list" is also accepted by the branch above, although this message does not mention it.
            raise TypeError(f"For node '{self.get_name()}', only support (pass_through, tuple or dict method) as "
                            f"call_method, but got {self._func_name.value}")
1291
1292    def _sync_return_node_to_ast(self):
1293        """
1294        Sync args to value of ast.Return from self._normalized_args when NodeType is Output.
1295
1296        For node with type of CallMethod, the value of ast.Assign is one of:
1297        (ast.Name, ast.Attribute)
1298        """
1299        if self._ast_node is None:
1300            return
1301        return_ast = self._ast_node
1302        if not isinstance(return_ast, ast.Return):
1303            raise TypeError(f"For node '{self.get_name()}', return_ast should be ast.Return, got: {type(return_ast)}")
1304        return_value_ast = return_ast.value
1305        return_ast.value = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
1306                                                        return_value_ast)
1307
1308    def _sync_mathops_node_args_to_ast(self):
1309        """
1310        Sync values from self._normalized_args to the ast node for mathematical operations.
1311        """
1312        if self._ast_node is None:
1313            return
1314        if not isinstance(self._ast_node, ast.Assign):
1315            raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
1316        mathops_node = self._ast_node.value
1317        if isinstance(mathops_node, ast.BinOp):
1318            left = mathops_node.left
1319            right = mathops_node.right
1320            mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
1321                                                             left)
1322            mathops_node.right = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[1]),
1323                                                              right)
1324        elif isinstance(mathops_node, ast.UnaryOp):
1325            operand = mathops_node.operand
1326            mathops_node.operand = \
1327                AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
1328        elif isinstance(mathops_node, ast.BoolOp):
1329            values = mathops_node.values
1330            for arg_index in range(self._args_num):
1331                arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
1332                values[arg_index] = AstModifier.get_ast_by_value(arg_value, values[arg_index])
1333        elif isinstance(mathops_node, ast.Compare):
1334            left = mathops_node.left
1335            mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
1336                                                             left)
1337            comparators = mathops_node.comparators
1338            for arg_index in range(1, self._args_num):
1339                arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
1340                comparators[arg_index - 1] = AstModifier.get_ast_by_value(arg_value, comparators[arg_index - 1])
1341        else:
1342            raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
1343                            "ast.BoolOp, ast.Compare), but got ", type(mathops_node))
1344
1345    def _sync_control_flow_args_to_ast(self):
1346        """
1347        Sync values from self._normalized_args to the ast node of control flow.
1348        """
1349        if self._ast_node is None:
1350            return
1351        normalized_args_num = len(self._normalized_args_keys)
1352        if normalized_args_num == 0:
1353            return
1354        if normalized_args_num > 1:
1355            raise ValueError("self._normalized_args_keys should have less than 1 elements")
1356        arg_value = self._normalized_args.get(self._normalized_args_keys[0])
1357        if isinstance(self._ast_node, (ast.If, ast.IfExp, ast.While)):
1358            self._ast_node.test = AstModifier.get_ast_by_value(arg_value, self._ast_node.test)
1359        elif isinstance(self._ast_node, ast.For):
1360            self._ast_node.iter = AstModifier.get_ast_by_value(arg_value, self._ast_node.iter)
1361        else:
1362            raise ValueError(f"For Control Flow, ast_node should be one of [ast.If, ast.IfExp, "
1363                             f"ast.While, ast.For], but got {type(self._ast_node)}")
1364
1365    def _sync_arg(self):
1366        """Sync _normalized_args to corresponding ast node when updated."""
1367        if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, \
1368                               NodeType.CellContainer, NodeType.CallFunction):
1369            self._sync_call_args_to_ast()
1370        elif self._node_type == NodeType.Output:
1371            self._sync_return_node_to_ast()
1372        elif self._node_type == NodeType.CallMethod:
1373            self._sync_call_method_args_to_ast()
1374        elif self._node_type == NodeType.MathOps:
1375            self._sync_mathops_node_args_to_ast()
1376        elif self._node_type == NodeType.ControlFlow:
1377            self._sync_control_flow_args_to_ast()
1378
1379
1380# Child classes
class TreeNode(Node):
    """Node of type NodeType.Tree that keeps a handle (`symbol_tree`) to the SymbolTree of a sub-network."""

    def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
                 args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
        """
        Constructor of TreeNode. Prefer the class method `create_tree_node` over calling this constructor
        directly.

        Args:
            tree: Instance of SymbolTree acting as the handle of the sub-symbol-tree.
            ast_node (ast.AST): The ast node this node corresponds to.
            targets (list[ScopedValue]): Targets of the node. See docstring of Node class.
            func (Union[ScopedValue, str]): Callee of the node. A str is converted into a naming ScopedValue.
            args (list[ScopedValue]): Positional arguments. See docstring of Node class.
            kwargs (dict{str: ScopedValue}): Keyword arguments, keyed by argument name. See docstring of
                Node class.
            name (str): Name of the node. It becomes unique once inserted into a SymbolTree and is also used
                as field name in the network class.
            instance: Object in the network that corresponds to this node.
        """
        func_value = ScopedValue.create_naming_value(func) if isinstance(func, str) else func
        super().__init__(NodeType.Tree, ast_node, targets, func_value, args, kwargs, name, instance)
        self.symbol_tree = tree

    @classmethod
    def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
                         func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue},
                         name: str = "", instance=None):
        """
        Class method of TreeNode. Build a node of type Tree, i.e. a node that invokes a sub-network.

        Args:
            tree: Instance of SymbolTree acting as the handle of the sub-symbol-tree.
            ast_node (ast.AST): The ast node this node corresponds to.
            targets (Union[ScopedValue, str]): Targets of the node; they are normalized by
                Node._handle_targets before being passed to the constructor.
            func_name (Union[ScopedValue, str]): Callee of the node. A str is converted into a naming
                ScopedValue.
            args (list[ScopedValue]): Positional arguments. See docstring of Node class.
            kwargs (dict{str: ScopedValue}): Keyword arguments, keyed by argument name. See docstring of
                Node class.
            name (str): Name of the node. It becomes unique once inserted into a SymbolTree and is also used
                as field name in the network class.
            instance: Object in the network that corresponds to this node.

        Returns:
            An instance of `cls`.
        """
        normalized_targets = Node._handle_targets(targets)
        if isinstance(func_name, str):
            func_value = ScopedValue.create_naming_value(func_name)
        else:
            func_value = func_name
        return cls(tree, ast_node, normalized_targets, func_value, args, kwargs, name, instance)
1429