• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Parse ast.Assign in construct function to node of SymbolTree."""
16from typing import Union, List, Dict
17import types
18import os
19import ast
20import sys
21import inspect
22import builtins
23from textwrap import dedent
24
25from mindspore import log as logger
26from mindspore.nn import Cell, SequentialCell, CellList
27from mindspore.ops.primitive import Primitive
28import mindspore.ops.functional as F
29from . import Parser, ParserRegister, reg_parser
30from ..symbol_tree import SymbolTree
31from ..node import Node, TreeNode, NodeManager, CallFunction, CellContainer, ControlFlow, LocalPrim
32from ..api.scoped_value import ScopedValue
33from ..ast_helpers import AstFlattener, AstConverter, AstFinder
34from ..common.error_log import error_str
35from ..common.namespace import is_subtree, is_ms_function, is_third_party
36from ..common.namer import FunctionNamer
37
38
39if sys.version_info >= (3, 9):
40    import ast as astunparse # pylint: disable=reimported, ungrouped-imports
41else:
42    import astunparse
43
44
45class AssignParser(Parser):
46    """Parse ast.Assign in construct function to node of SymbolTree."""
47
48    # Types for creating Cell Container node
49    types_for_cell_container = [SequentialCell,]
50    # If mindspore built-in function to be parsered or skipped
51    _skip_ms_function = False
52    # Functions in black list will not be parsed
53    _function_parse_black_list = [F.arange]
54    # Share one implementation for the same instances
55    _share_one_implementation = False
56    # Implementation caches of sub SymbolTrees, CallFunction nodes and CellContainer nodes
57    # Keys are ids of the instance object
58    _cached_trees: Dict[int, SymbolTree] = {}
59    _cached_functions: Dict[int, Node] = {}
60    _cached_cell_containers: Dict[int, Node] = {}
61
62    def __init__(self):
63        super().__init__()
64        self._variables_cache = []
65        self.stree: SymbolTree = None
66        self.ast_assign: ast.Assign = None
67        self.node_manager: NodeManager = None
68        self.targets: List[ScopedValue] = None
69        self.args: List[ScopedValue] = None
70        self.kwargs: Dict[str, ScopedValue] = None
71
72    @staticmethod
73    def _get_func_name(ast_call: ast.Call) -> str:
74        """
75        Get the func name from ast.Call.
76
77        Args:
78            ast_call (ast.Call): Input ast.Call node.
79
80        Returns:
81            Func name.
82        """
83        func = ast_call.func
84        if isinstance(func, ast.Name):
85            return func.id
86        if isinstance(func, ast.Attribute):
87            return func.attr
88        func_full_name = astunparse.unparse(func).strip()
89        if func_full_name.count('.') > 0:
90            return func_full_name.split('.')[-1]
91        return func_full_name
92
93    @staticmethod
94    def _get_func_scope(ast_call: ast.Call) -> str:
95        """
96        Get the func scope from ast.Call.
97
98        Args:
99            ast_call (ast.Call): Input ast.Call node.
100
101        Returns:
102            Func scope.
103        """
104        func = ast_call.func
105        if isinstance(func, ast.Name):
106            return ""
107        func_full_name = astunparse.unparse(func).strip()
108        if func_full_name.count('.') > 0:
109            return func_full_name.rsplit('.', 1)[0]
110        return ""
111
112    @staticmethod
113    def _create_targets(ast_target: ast.AST) -> List[ScopedValue]:
114        """Get targets from ast node."""
115        ast_target_elems = AstConverter.get_ast_target_elems(ast_target)
116        targets = [AstConverter.create_scopedvalue(ast_node) for ast_node in ast_target_elems]
117        return targets
118
119    @staticmethod
120    def _create_kwargs(keywords: [ast.keyword]) -> Dict[str, ScopedValue]:
121        """
122        Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node.
123
124        Args:
125            keywords ([ast.keyword]): Keywords of ast.Call node.
126
127        Returns:
128            A dict of ScopedValue.
129        """
130        results = {}
131        for keyword in keywords:
132            results[keyword.arg] = AstConverter.create_scopedvalue(keyword.value)
133        return results
134
135
136    @staticmethod
137    def _get_inst_and_name(ast_node: ast.Attribute, stree: SymbolTree):
138        """
139        Try to get instance object of ast_node from ast.Attribute.
140        """
141        if not isinstance(ast_node, ast.Attribute):
142            return None, ""
143        scope_name = astunparse.unparse(ast_node).strip()
144        scope, name = scope_name.split('.', 1)
145        if scope != 'self':
146            return None, scope_name
147        if not hasattr(stree.get_origin_network(), name):
148            return None, scope_name
149        return getattr(stree.get_origin_network(), name), scope_name
150
151    @staticmethod
152    def _list_of_cells(cell_list: list):
153        """Check if elements in the list are all cells."""
154        for item in cell_list:
155            if not isinstance(item, Cell):
156                return False
157        return True
158
159    @staticmethod
160    def _get_path_of_node_manager(node_manager: NodeManager):
161        """Get file path of type(instance) in NodeManager"""
162        node_manager = node_manager.get_top_manager()
163        if isinstance(node_manager, SymbolTree):
164            return inspect.getfile(type(node_manager.get_origin_network()))
165        return inspect.getfile(node_manager.get_instance())
166
167    @staticmethod
168    def _get_module_of_node_manager(node_manager: NodeManager):
169        """Get module where the node manager is located"""
170        # get module where function object is used
171        func_path = AssignParser._get_path_of_node_manager(node_manager)
172        func_path = os.path.normcase(os.path.normpath(func_path))
173        modules = list(sys.modules.values())
174        for m in modules:
175            if hasattr(m, "__file__") and m.__file__ is not None and func_path == os.path.normcase(m.__file__):
176                return m, func_path
177        return None, func_path
178
179    @staticmethod
180    def _get_object_from_module(func_full_name: str, module: types.ModuleType):
181        """Get object from module according to full name of function"""
182        names = func_full_name.split('.')
183        obj = module
184        for attr in names:
185            if not hasattr(obj, attr):
186                logger.info(f"For '{func_full_name}', failed to get attr '{attr}' from '{obj}'")
187                return None
188            obj = getattr(obj, attr)
189        return obj
190
191    @staticmethod
192    def _get_local_var_provider(node_manager: NodeManager, var: str) -> Node:
193        """Get the node providing specific variable"""
194        node = node_manager.get_tail()
195        while node is not None:
196            if var in [str(target) for target in node.get_targets()]:
197                return node
198            node = node.get_prev()
199        # When node_manager is control flow, nodes in upper node_manager need to be traversed.
200        if isinstance(node_manager, ControlFlow):
201            return AssignParser._get_local_var_provider(node_manager.get_node_manager(), var)
202        return None
203
204    def target(self):
205        """Parse target type."""
206        return ast.Assign
207
208    def store_env(self):
209        """Store current environments"""
210        self._variables_cache.append(
211            [self.stree, self.ast_assign, self.node_manager, self.targets, self.args, self.kwargs])
212        self.stree = None
213        self.ast_assign = None
214        self.node_manager = None
215        self.targets = None
216        self.args = None
217        self.kwargs = None
218
219    def restore_env(self):
220        """Restore last environments"""
221        self.stree, self.ast_assign, self.node_manager, self.targets, self.args, self.kwargs = \
222            self._variables_cache.pop()
223
224    def _get_cell_instance(self, func_scope, func_name):
225        """
226        Get object instance from ast.Call with type of Cell.
227
228        Args:
229            func_scope (str): Func scope.
230            func_name (str): Func name.
231
232        Returns:
233            An instance represents operator instance.
234        """
235        if func_scope != "self":
236            return None
237        var_dict = self.stree.get_origin_network().__dict__
238        # Instance is of type Cell
239        for key, value in var_dict["_cells"].items():
240            if key == func_name:
241                return value
242        # Instance is of other type.
243        return None
244
245    def _get_primitive_instance(self, func_scope, func_name):
246        """
247        Get object instance from ast.Call with type of Primitive.
248
249        Args:
250            func_scope (str): Func scope.
251            func_name (str): Func name.
252
253        Returns:
254            An instance represents operator instance.
255        """
256        if func_scope != "self":
257            return None
258        var_dict = self.stree.get_origin_network().__dict__
259        # Instance is of type Primitive
260        for key, value in var_dict["_primitives"].items():
261            if key == func_name:
262                return value
263        # Instance is of other type.
264        return None
265
266    def _get_method_object(self, func_scope, func_name):
267        """Get method object from network instance."""
268        stree = self.stree
269        if func_scope in ('self', stree.get_opt_cls_name()) and hasattr(stree.get_origin_network(), func_name):
270            return getattr(stree.get_origin_network(), func_name)
271        return None
272
273    def _get_local_variable(self, func_scope, func_name) -> (bool, object):
274        """
275        Get local variable
276
277        Args:
278            func_scope (str): Func scope.
279            func_name (str): Func name.
280
281        Returns:
282            bool: Indicate whether local variable is found.
283            object (Union[LocalPrim, type]): Instance of LocalPrim when calling the class, or class type
284                object when initializing the class.
285        """
286        func_full_name = f"{func_scope}.{func_name}" if func_scope else func_name
287        # try to find func_name in class variables initializing the primitive during forward method
288        provider_node = None
289        if func_scope == "self":
290            for node in self.stree.local_prim_inits():
291                if func_full_name in [str(target) for target in node.get_targets()]:
292                    provider_node = node
293        # try to find func_name in local variables
294        if provider_node is None:
295            provider_node = AssignParser._get_local_var_provider(self.node_manager, func_full_name)
296        if provider_node:
297            # when the node providering the local variable initialized a primitive during forward method,
298            # we use LocalPrim to indicate the instance of this primitive. e.g. :
299            # abs_inst = P.Abs()  -> 'abs_inst' is an instance of primitive initialized locally
300            # y = abs_inst(x)     -> here we are parsing now
301            cls_init = provider_node.get_init_cls()
302            if cls_init and inspect.isclass(cls_init) and issubclass(cls_init, Primitive):
303                return True, LocalPrim(cls_init)
304            # when the node providering the local variable represent a primitive type object, we return
305            # type-object to indicate that we are initializing this primitive. e.g. :
306            # abs_ops = _get_cache_prim(P.Abs)  -> 'abs_ops' is an primitive type object
307            # y = abs_ops(x)                    -> here we are parsing now
308            cls_type = provider_node.get_type_cls()
309            if cls_type and inspect.isclass(cls_type) and issubclass(cls_type, Primitive):
310                return True, cls_type
311            # local variable whose type is not primitive instance
312            logger.info(f"Ignore local variable: {func_full_name}")
313            return True, None
314        # other local variable
315        if AssignParser._get_local_var_provider(self.node_manager, func_full_name.split('.')[0]):
316            logger.info(f"Ignore local variable: {func_full_name}")
317            return True, None
318        return False, None
319
320    def _get_function_object(self, func_scope, func_name, ast_call) -> (object, bool):
321        """
322        Get function object from module.
323
324        If the code represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs),
325        return primitive type object with class type flag True.
326
327        if the code represent an initializtion of a class, e.g. abs_inst = P.Abs(),
328        return primitive type object with class type flag False.
329
330        if the code represent the call of function or class instance, e.g. y = abs_inst(x)/func(x),
331        return primitive instance or function object with class type flag False.
332
333        Args:
334            func_scope (str): Func scope.
335            func_name (str): Func name.
336            ast_call (ast.Call): ast.Call of ast.Assign.
337
338        Returns:
339            object: Class type object, class instance or function object
340            bool: Flag indicate is node represent a class type object.
341        """
342        func_full_name = f"{func_scope}.{func_name}" if func_scope else func_name
343        # get module where function object is used
344        module, func_path = AssignParser._get_module_of_node_manager(self.node_manager)
345        if module is None:
346            logger.debug(f"When getting object of '{func_full_name}', failed to find module in '{func_path}'")
347            return None, False
348        # if name of function is _get_cache_prim, return primitive type object
349        is_cls_type_obj = False
350        if func_full_name == '_get_cache_prim':
351            func_full_name = astunparse.unparse(ast_call.args[0]).strip()
352            is_cls_type_obj = True
353        # find object in module
354        obj = AssignParser._get_object_from_module(func_full_name, module)
355        return obj, is_cls_type_obj
356
357    def _update_field_in_init(self, func_name: str, sub_tree: SymbolTree) -> bool:
358        """
359        When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
360        Add the code like: `self.field = SubNetwork(self.field)`
361
362        Args:
363            func_name (str): A string represents scope and name of function symbol.
364            sub_tree (SymbolTree): The SymbolTree corresponding to sub-network.
365        """
366        init_func_ast = self.stree.get_init_func_ast()
367        sub_net_obj = sub_tree.get_origin_network()
368        sub_net_opt_name = sub_tree.get_opt_cls_name()
369        # Add .to_float(mindspore.float16) if origin subnet has this attribute
370        new_code = f"{func_name} = {sub_net_opt_name}({func_name})"
371        if hasattr(sub_net_obj, "fp16") and sub_net_obj.fp16:
372            new_code = f"{new_code}.to_float(mindspore.float16)"
373        elif hasattr(sub_net_obj, "bf16") and sub_net_obj.bf16:
374            new_code = f"{new_code}.to_float(mindspore.bfloat16)"
375        new_ast = ast.parse(new_code).body[0]
376        init_func_ast.body.append(new_ast)
377
378    def _update_cell_container_in_init(self, container_name, container_idx, subnet_opt_name):
379        """
380        When nn.SequentialCell include sub-symboltree, the new class definition will be used to create object.
381        So the assign code will be got from origin code first, and then be modified to new class name.
382
383        Codes like:
384
385        `self.container = nn.SequentialCell([ReLU(), MyNet()])`
386
387        will be updated by add codes:
388
389        `self.container[1] = MyNetOpt(self.container[1])`
390
391        """
392        new_code = f"{container_name}[{container_idx}] = {subnet_opt_name}({container_name}[{container_idx}])"
393        new_ast = ast.parse(new_code).body[0]
394        self.stree.get_init_func_ast().body.append(new_ast)
395
396    def _add_import(self, import_name: str):
397        """ add import to current node manager."""
398        module, _ = AssignParser._get_module_of_node_manager(self.node_manager)
399        if module is None:
400            logger.info(f"Cannot get module where '{import_name}' is located, ignore import info")
401            return
402        node_manager = self.node_manager.get_top_manager()
403        belonging_ast = None if isinstance(node_manager, SymbolTree) else node_manager.get_manager_ast()
404        self.stree.add_import(module, import_name, belonging_ast)
405
    def cell_container_process(self, func_name: Union[str, ScopedValue], node_name: str, container_obj: object):
        """
        Parse a cell container object (e.g. SequentialCell) into a CellContainer node.

        Every element of the container becomes a sub node of the CellContainer: nested containers are processed
        recursively, user-defined sub-networks (``is_subtree``) become TreeNode instances built by SymbolTreeBuilder,
        and all other cells become built-in op nodes. For each sub-network element, a statement
        ``<func_name>[i] = <OptClass>(<func_name>[i])`` is appended to the ``__init__`` ast.

        Args:
            func_name (Union[str, ScopedValue]): Scope and name of the container, e.g. ``self.container``; it is
                formatted with f-strings to address elements (process_cell passes a ScopedValue here).
            node_name (str): Name of the created node.
            container_obj (object): The container instance; iterated to obtain its cells.

        Returns:
            Node: The created CellContainer node, or a built-in op node when implementations are shared and this
            container object has already been parsed.
        """
        # create unparsable node if container is already parsed when sharing one implementation
        if AssignParser._share_one_implementation and id(container_obj) in AssignParser._cached_cell_containers:
            cell_container = Node.create_call_buildin_op(container_obj, self.ast_assign, self.targets,
                                                         func_name, self.args, self.kwargs, node_name)
            return cell_container
        cell_container = CellContainer(self.ast_assign, self.targets, func_name, self.args, self.kwargs,
                                       node_name, self.stree, container_obj)
        for i, cell in enumerate(container_obj):
            cell_name = type(cell).__name__
            # The type of cell is container of cells (e.g. SequentialCell); nested elements are addressed as
            # `<func_name>[i][j]` through the recursive call.
            if isinstance(cell, tuple(AssignParser.types_for_cell_container)):
                sub_node = self.cell_container_process(f"{func_name}[{i}]", cell_name, cell)
            elif is_subtree(cell):
                # create unparsable node if tree node is already parsed when sharing one implementation;
                # the element is still re-wrapped in __init__ with the optimized class of the first tree.
                if AssignParser._share_one_implementation and id(cell) in AssignParser._cached_trees:
                    first_stree = AssignParser._cached_trees.get(id(cell))
                    self._update_cell_container_in_init(func_name, i, first_stree.get_opt_cls_name())
                    sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args,
                                                           self.kwargs, cell_name)
                else:
                    # Imported locally (presumably to avoid a circular import with ..symbol_tree).
                    from ..symbol_tree import SymbolTreeBuilder
                    stb = SymbolTreeBuilder(cell)
                    new_stree = stb.build()
                    # Sub nodes are created with ast None: they do not own an ast.Assign of their own.
                    sub_node = TreeNode.create_tree_node(new_stree, None, self.targets, cell_name, self.args,
                                                         self.kwargs, cell_name, cell)
                    self._update_cell_container_in_init(func_name, i, new_stree.get_opt_cls_name())
                    # save symbol tree if it is firstly parsed when sharing one implementation
                    if AssignParser._share_one_implementation:
                        AssignParser._cached_trees[id(cell)] = new_stree
            else:
                sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args,
                                                       self.kwargs, cell_name)
            # add sub node to cell_container
            # NOTE(review): the meaning of the second argument (False) is defined by CellContainer.append - confirm.
            cell_container.append(sub_node, False)
        # save the node if container is firstly parsed when sharing one implementation
        if AssignParser._share_one_implementation:
            AssignParser._cached_cell_containers[id(container_obj)] = cell_container
        return cell_container
446
447    def process_cell(self, func_scope_name: ScopedValue, node_name: str, cell_inst: Cell):
448        """Create CallCell node with instance of cell."""
449        # The type of cell is container of cells (e.g. SequentialCell)
450        if isinstance(cell_inst, tuple(AssignParser.types_for_cell_container)):
451            node = self.cell_container_process(func_scope_name, node_name, cell_inst)
452        # The type of cell is user custom network, then we create sub-symboltree
453        elif is_subtree(cell_inst):
454            # create unparsable node if tree node is already parsed when sharing one implementation
455            if AssignParser._share_one_implementation and id(cell_inst) in AssignParser._cached_trees:
456                first_stree = AssignParser._cached_trees.get(id(cell_inst))
457                self._update_field_in_init(str(func_scope_name), first_stree)
458                node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name,
459                                                   self.args, self.kwargs, node_name)
460            else:
461                from ..symbol_tree import SymbolTreeBuilder
462                stb = SymbolTreeBuilder(cell_inst)
463                new_stree = stb.build()
464                self._update_field_in_init(str(func_scope_name), new_stree)
465                node = TreeNode.create_tree_node(new_stree, self.ast_assign, self.targets, func_scope_name,
466                                                 self.args, self.kwargs, node_name, new_stree.get_origin_network())
467                # save symbol tree if it is firstly parsed when sharing one implementation
468                if AssignParser._share_one_implementation:
469                    AssignParser._cached_trees[id(cell_inst)] = new_stree
470        else:
471            # The type of cell is built-in cells
472            node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name, self.args,
473                                               self.kwargs, node_name)
474        self.stree.append_origin_field(node, self.node_manager)
475
476    def process_primitive(self, func_scope_name: ScopedValue, node_name: str, primitive_inst: Primitive):
477        """Create CallPrimitive node with instance of primitive."""
478        node = Node.create_call_buildin_op(primitive_inst, self.ast_assign, self.targets, func_scope_name,
479                                           self.args, self.kwargs, node_name)
480        self.stree.append_origin_field(node, self.node_manager)
481
482    def process_class_method(self, func_scope_name: ScopedValue, node_name: str, method_object: object):
483        """Create CallFunction node for class method function."""
484        func_name = func_scope_name.value
485        # get ast.FunctionDef
486        ast_functiondef = None
487        for body in self.stree.get_class_ast().body:
488            if isinstance(body, ast.FunctionDef) and func_name == body.name:
489                ast_functiondef = body
490        if ast_functiondef is None:
491            # method of child class may be called and will be ignored now.
492            logger.info(error_str(f"Find ast of function '{func_name}' in network '{self.stree.get_ori_cls_name()}' "
493                                  f"failed", child_node=self.ast_assign))
494            self.insert_callfunction_node(func_scope_name, node_name, None, None, False)
495        else:
496            # create CallFunction node
497            self.insert_callfunction_node(func_scope_name, node_name, ast_functiondef, method_object, True)
498
499    def process_function(self, func_scope_name: ScopedValue, node_name: str, function_object: object,
500                         is_cls_type_obj: bool):
501        """Create node for function."""
502        # Ignore functions in _function_parse_black_list
503        if function_object in AssignParser._function_parse_black_list:
504            logger.debug(f"'{func_scope_name}' is in the _function_parse_black_list and will not be parsed")
505            if not func_scope_name.scope:
506                self._add_import(func_scope_name.value)
507            self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
508            return
509        # break loop function
510        node_manager = self.node_manager
511        while node_manager and isinstance(node_manager, Node):
512            if isinstance(node_manager, CallFunction) and node_manager.get_instance() == function_object:
513                logger.info(f"loop function detected in '{func_scope_name}', stop parsing function.")
514                self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
515                return
516            node_manager = node_manager.get_node_manager()
517        # process primitive instances:
518        # (global/local) _ops_func = P.FUNC()
519        # (here) y = _ops_func(x) <- (process: _ops_func)
520        if isinstance(function_object, Primitive):
521            # when primitive instance is not a local variable, it will be a global object which need to be imported
522            if not isinstance(function_object, LocalPrim):
523                import_name = str(func_scope_name).split('.')[0]
524                self._add_import(import_name)
525            # create CallPrimitive node
526            self.process_primitive(func_scope_name, func_scope_name.value, function_object)
527            return
528        # process primitive object:
529        # (here) _ops_func = P.FUNC() <- (process: P.FUNC)
530        # (later) y = _ops_func(x)
531        if inspect.isclass(function_object):
532            node = self.insert_callfunction_node(func_scope_name, node_name, None, None, False)
533            if is_cls_type_obj:
534                # represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs)
535                node.set_type_cls(function_object)
536                # add import
537                if str(func_scope_name) == '_get_cache_prim':
538                    import_name = astunparse.unparse(self.ast_assign.value.args[0]).strip()
539                    if '.' not in import_name:
540                        self._add_import(import_name)
541            else:
542                # represent the initialize of a class type, e.g. abs_inst = P.Abs()
543                node.set_init_cls(function_object)
544                # record local primitive objects
545                if func_scope_name.scope == 'self' and issubclass(function_object, Primitive):
546                    self.stree.local_prim_inits.append(node)
547            return
548        # process third party functions
549        is_ms_func = is_ms_function(function_object)
550        if not is_ms_func and is_third_party(function_object):
551            logger.info(f"Ignore third party function '{func_scope_name}'.")
552            self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
553            return
554        # process mindspore functions
555        if is_ms_func and AssignParser._skip_ms_function:
556            logger.info(f"Ignore mindspore function '{func_scope_name}'.")
557            self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
558            return
559        # get ast.FunctionDef
560        source_code = inspect.getsource(function_object)
561        ast_functiondef = ast.parse(dedent(source_code)).body[0]
562        if not isinstance(ast_functiondef, ast.FunctionDef):
563            logger.info(error_str(f"Get ast.FunctionDef of function {str(func_scope_name)} failed, the type of "
564                                  f"ast node is {type(ast_functiondef)}", child_node=self.ast_assign))
565            self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
566            return
567        if [n for n in ast_functiondef.body if isinstance(n, ast.FunctionDef)]:
568            logger.info(error_str(f"closure syntax is not supported now, {str(func_scope_name)} will not be parsed.",
569                                  child_node=ast_functiondef))
570            if not func_scope_name.scope:
571                self._add_import(func_scope_name.value)
572            self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
573            return
574        # update func_name, and remove scope
575        new_name = ast_functiondef.name
576        # when func_scope_name(e.g. 'C.uniform') is not the name in ast.FunctionDef(e.g. 'uniform'), this name may be
577        # already used as variable(e.g. uniform = C.uniform(x)).
578        # To avoid new function's name being duplicated with existed variable, an suffix '_opt' will be added.
579        if new_name != str(func_scope_name):
580            new_name = f"{new_name}_opt"
581        new_name = FunctionNamer().instance().get_name(new_name)
582        # create unparsable node if function is already parsed when sharing one implementation
583        if AssignParser._share_one_implementation and id(function_object) in AssignParser._cached_functions:
584            first_node = AssignParser._cached_functions.get(id(function_object))
585            ast_call: ast.Call = self.ast_assign.value
586            ast_call.func = ast.Name(id=str(first_node.get_func_name()), ctx=ast.Load())
587            self.insert_callfunction_node(func_scope_name, new_name, None, function_object, False)
588            return
589        ast_functiondef.name = new_name
590        ast_call: ast.Call = self.ast_assign.value
591        ast_call.func = ast.Name(id=new_name, ctx=ast.Load())
592        # save ast.FunctionDef into stree._external_ast
593        self.stree.get_external_ast()[ast_functiondef] = []
594        # import module which function defined in
595        func_file_path = inspect.getabsfile(function_object)
596        self.stree.save_imports_from_file(func_file_path, ast_functiondef)
597        # create CallFunction node
598        func_scope_name = ScopedValue.create_naming_value(new_name, "")
599        node = self.insert_callfunction_node(func_scope_name, new_name, ast_functiondef, function_object, False)
600        # save function node if it is firstly parsed when sharing one implementation
601        if AssignParser._share_one_implementation:
602            AssignParser._cached_functions[id(function_object)] = node
603
604    def insert_callfunction_node(self, func_name: ScopedValue, node_name: str, ast_functiondef: ast.FunctionDef,
605                                 func_obj: object, is_method: bool) -> Node:
606        """Create CallFunction node for function."""
607        if ast_functiondef is None:
608            node = Node.inner_create_call_function(node_name, self.ast_assign, func_name, func_obj,
609                                                   self.targets, self.args, self.kwargs)
610            self.stree.append_origin_field(node, self.node_manager)
611            return node
612        # create CallFunction node
613        node = CallFunction(self.targets, func_name, self.args, self.kwargs, node_name, self.ast_assign,
614                            ast_functiondef, self.stree, func_obj, is_method)
615        self.stree.append_origin_field(node, self.node_manager)
616        # expand ast codes
617        ast_functiondef = AstFlattener().transform(ast_functiondef, [func_name.value], self.stree)
618        # parse ast codes into CallFunction Node
619        parser = ParserRegister.instance().get_parser(ast.FunctionDef)
620        parser.process(self.stree, ast_functiondef, node_manager=node)
621        return node
622
    def process_ast_call(self, ast_call: ast.Call):
        """
        Convert ast.Call to a symbol tree node.

        The callee is examined in a fixed order and the first matching case decides how the call is handled:
        nested calls inside the callee expression, names found in ``dir(builtins)``, loop variables of an
        enclosing ControlFlow, Cell instances, Primitive instances, method objects, local variables and finally
        function objects. Anything not resolved falls back to a plain call-function node.

        Side effects: sets ``self.targets``, ``self.args`` and ``self.kwargs``, which the node-creating helpers
        read afterwards.

        Args:
            ast_call (ast.Call): An ast.Call of assign node in construct.
        """
        # Shared call information used by every branch below.
        self.targets = AssignParser._create_targets(self.ast_assign.targets[0])
        self.args = [AstConverter.create_scopedvalue(arg) for arg in ast_call.args]
        self.kwargs = AssignParser._create_kwargs(ast_call.keywords)
        func_name = AssignParser._get_func_name(ast_call)
        func_scope = AssignParser._get_func_scope(ast_call)
        func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
        func_full_name = str(func_scope_name)
        # y = func(xxx)(xxx) / y = func1(xxx).func2(xxx) is not supported, and should be flattened before parsing.
        if AstFinder(ast_call.func).find_all(ast.Call):
            logger.info(error_str("ast.Call in func name of ast.Call is not supported.", ast_call, self.ast_assign))
            self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
            return
        # Ignore built-in functions
        if func_full_name in dir(builtins):
            logger.info(f"Ignore built-in function: {func_scope_name}")
            self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
            return
        # Ignore function name is target of for loop
        if isinstance(self.node_manager, ControlFlow) and func_full_name in self.node_manager.loop_vars:
            logger.info(f"Ignore function of loop variable: {func_scope_name}")
            self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
            return
        # Instance with type of Cell
        cell_inst = self._get_cell_instance(func_scope, func_name)
        if cell_inst is not None:
            self.process_cell(func_scope_name, func_name, cell_inst)
            return
        # Instance with type of Primitive
        primitive_inst = self._get_primitive_instance(func_scope, func_name)
        if primitive_inst is not None:
            self.process_primitive(func_scope_name, func_name, primitive_inst)
            return
        # Class method object
        method_object = self._get_method_object(func_scope, func_name)
        if method_object is not None:
            # Bound methods are handled by process_class_method; attributes of the original network that are
            # staticmethods become plain call nodes; anything else is expanded through process_function.
            if inspect.ismethod(method_object):
                self.process_class_method(func_scope_name, func_name, method_object)
            elif isinstance(inspect.getattr_static(self.stree.get_origin_network(), func_name), staticmethod):
                self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
            else:
                self.process_function(func_scope_name, func_name, method_object, False)
            return
        # Local variable
        # _get_local_variable reports whether the name is a local variable and, separately, an object
        # (named primitive_obj here) which, when not None, is expanded as a function.
        is_local_var, primitive_obj = self._get_local_variable(func_scope, func_name)
        if primitive_obj is not None:
            self.process_function(func_scope_name, func_name, primitive_obj, False)
            return
        if is_local_var:
            # for a variable whose type is not primitive instance, create normal node for it
            self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
            return
        # Function object
        function_object, is_cls_type_obj = self._get_function_object(func_scope, func_name, ast_call)
        if function_object is not None:
            self.process_function(func_scope_name, func_name, function_object, is_cls_type_obj)
            return
        # Nothing matched: keep the call as an opaque call-function node.
        logger.info(error_str("Failed to get instance or object of ast.Call.", ast_call, self.ast_assign))
        self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
688
689    def process_ast_mathops(self, ast_op: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]):
690        """
691        Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
692        a symbol tree node.
693
694        Args:
695            ast_op (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
696                operation in construct function.
697
698        Raises:
699            TypeError: The type of parameter 'ast_op' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
700
701        """
702        if not isinstance(ast_op, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
703            raise TypeError("The type of parameter 'ast_op' must be one of (ast.BinOp, ast.UnaryOp, "
704                            "ast.BoolOp, ast.Compare), but got ", type(ast_op))
705
706        targets = AssignParser._create_targets(self.ast_assign.targets[0])
707        args = []
708        op_type_str = type(ast_op).__name__
709        op_type = ScopedValue.create_naming_value(op_type_str)
710        name = op_type_str
711        if isinstance(ast_op, ast.BinOp):
712            op = type(ast_op.op).__name__
713            name = f'{name}_{op}'
714            args.append(AstConverter.create_scopedvalue(ast_op.left))
715            args.append(AstConverter.create_scopedvalue(ast_op.right))
716        elif isinstance(ast_op, ast.UnaryOp):
717            op = type(ast_op.op).__name__
718            name = f'{name}_{op}'
719            args.append(AstConverter.create_scopedvalue(ast_op.operand))
720        elif isinstance(ast_op, ast.BoolOp):
721            op = type(ast_op.op).__name__
722            name = f'{name}_{op}'
723            for value in ast_op.values:
724                args.append(AstConverter.create_scopedvalue(value))
725        elif isinstance(ast_op, ast.Compare):
726            args.append(AstConverter.create_scopedvalue(ast_op.left))
727            for idx, ast_cmp_op in enumerate(ast_op.ops):
728                op = type(ast_cmp_op).__name__
729                name = f'{name}_{op}'
730                args.append(AstConverter.create_scopedvalue(ast_op.comparators[idx]))
731        name = name.lower()
732        node = Node.create_mathops_node(self.ast_assign, targets, op_type, args, name)
733        self.stree.append_origin_field(node, self.node_manager)
734
735    def process_ast_constant(self, ast_constant: Union[ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str]):
736        """
737        Convert ast node of constant types (ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str) to
738        a symbol tree node.
739        """
740        node_name = f"{type(ast_constant).__name__.lower()}_assign"
741        targets = AssignParser._create_targets(self.ast_assign.targets[0])
742        args = [AstConverter.create_scopedvalue(ast_constant)]
743        node = Node.create_call_method(self.ast_assign, targets, "pass_through", args, {}, node_name)
744        self.stree.append_origin_field(node, self.node_manager)
745
746    def process_ast_name(self, ast_node: Union[ast.Name, ast.Attribute]):
747        """
748        Convert ast node of ast.Name and ast.Attribute to a symbol tree node.
749        """
750        self.targets = AssignParser._create_targets(self.ast_assign.targets[0])
751        inst, scope_name = AssignParser._get_inst_and_name(ast_node, self.stree)
752        if inst is not None and (isinstance(inst, CellList) or
753                                 isinstance(inst, list) and AssignParser._list_of_cells(inst)):
754            node = self.cell_container_process(scope_name, scope_name, inst)
755        else:
756            node_name = f"{type(ast_node).__name__.lower()}_assign"
757            args = [AstConverter.create_scopedvalue(ast_node)]
758            node = Node.create_call_method(self.ast_assign, self.targets, "pass_through", args, {}, node_name)
759        self.stree.append_origin_field(node, self.node_manager)
760
761    def process_ast_tuple(self, ast_node: Union[ast.Tuple, ast.List]):
762        """
763        Convert ast node of ast.Tuple or ast.List to a symbol tree node.
764        """
765        # ensure that each element's type in tuple is supported by scopled value
766        if AstConverter.ast_tuple_elts_support_scopledvalue(ast_node):
767            targets = AssignParser._create_targets(self.ast_assign.targets[0])
768            args = []
769            for elt in ast_node.elts:
770                args.append(AstConverter.create_scopedvalue(elt))
771            func_name = "tuple" if isinstance(ast_node, ast.Tuple) else "list"
772            node = Node.create_call_method(self.ast_assign, targets, func_name, args, {}, func_name)
773            self.stree.append_origin_field(node, self.node_manager)
774        else:
775            logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported "
776                        "in rewrite, fallback to python")
777            self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager)
778
779    def process_ast_dict(self, ast_dict: ast.Dict):
780        """
781        Convert ast node of ast.Dict to a symbol tree node.
782        """
783        # ensure that each element's type in dict is supported by scopled value
784        if AstConverter.ast_dict_support_scopledvalue(ast_dict):
785            targets = AssignParser._create_targets(self.ast_assign.targets[0])
786            kwargs = {}
787            for idx, key in enumerate(ast_dict.keys):
788                kwargs[key.value] = AstConverter.create_scopedvalue(ast_dict.values[idx])
789            func_name = ScopedValue.create_naming_value("dict")
790            node = Node.create_call_method(self.ast_assign, targets, func_name, [], kwargs, "dict")
791            self.stree.append_origin_field(node, self.node_manager)
792        else:
793            logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported "
794                        "in rewrite, fallback to python")
795            self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager)
796
797    def process_ast_subscript(self, ast_subscript: ast.Subscript):
798        """
799        Convert ast node of ast.Subscript to a symbol tree node.
800        """
801        targets = AssignParser._create_targets(self.ast_assign.targets[0])
802        args = [AstConverter.create_scopedvalue(ast_subscript)]
803        node = Node.create_call_method(self.ast_assign, targets, "pass_through", args, {}, "subscript_var")
804        self.stree.append_origin_field(node, self.node_manager)
805
806    def process(self, stree: SymbolTree, node: ast.Assign, node_manager: NodeManager):
807        """
808        Parse ast.Assign and create a node in symbol tree.
809
810        - Create node when value of ast.Assign is in [ast.Call, ast.Name, ast.Constant, ast.Attribute].
811        - Create python node when value of ast.Assign is in [ast.BinOp, ast.BoolOp, ast.Subscript, ast.List, ast.Tuple,
812          ast.Dict].
813        - Other value types are not supported.
814
815        Args:
816            stree ([SymbolTree]): Symbol Tree under parsing.
817            node ([ast.Assign]): An ast.Assign node.
818            node_manager (NodeManager): NodeManager those asts belong to.
819        """
820        if len(node.targets) != 1:
821            logger.info(error_str(f"Continuous assignment statement(e.g. 'a = b = 1') should be flatten before.",
822                                  child_node=node))
823            stree.try_append_python_node(node, node, node_manager)
824            return
825
826        self.store_env()
827        self.stree = stree
828        self.ast_assign = node
829        self.node_manager = node_manager
830        value = node.value
831        if isinstance(value, ast.Call):
832            self.process_ast_call(value)
833        elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
834            self.process_ast_mathops(value)
835        elif isinstance(value, ast.Subscript):
836            self.process_ast_subscript(value)
837        elif isinstance(value, (ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str)):
838            self.process_ast_constant(value)
839        elif isinstance(value, (ast.Name, ast.Attribute)):
840            self.process_ast_name(value)
841        elif isinstance(value, (ast.Tuple, ast.List)):
842            self.process_ast_tuple(value)
843        elif isinstance(value, ast.Dict):
844            self.process_ast_dict(value)
845        else:
846            logger.info(f"ops-call({astunparse.unparse(node).strip()}) in assign will be supported in near feature, "
847                        f"ignored as a python node now")
848            stree.try_append_python_node(node, node, node_manager)
849        self.restore_env()
850
851
# Module-level AssignParser instance passed to reg_parser (presumably registering it so that ast.Assign
# statements are dispatched to it -- confirm in the parser package). Because a single instance is shared, its
# per-call state (stree, ast_assign, node_manager) is saved and restored by process() via store_env/restore_env.
g_assign_parser = reg_parser(AssignParser())
853