• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""SymbolTree builder."""
16from typing import Optional
17import ast
18import inspect
19from textwrap import dedent
20
21from mindspore.nn import Cell
22from .symbol_tree import SymbolTree
23from ..parsers import Parser, ParserRegister
24from ..ast_helpers import AstFlattener
25
26
class SymbolTreeBuilder:
    """
    `SymbolTreeBuilder` for building a SymbolTree from network.

    The source code of the network's class is parsed into a Python AST, the AST is pre-processed by
    `ast_transform`, and the result is handed to the parser registered for `ast.Module` to populate a `SymbolTree`.

    Args:
         network (Cell): An instance of Cell represents a network from which SymbolTree will be built.

    Raises:
        TypeError: If `network` is not an instance of Cell.
    """

    # Names of the entry functions of the forward computation process; passed to AstFlattener in `ast_transform`.
    entry_functions = ["construct"]

    def __init__(self, network: Cell):
        if not isinstance(network, Cell):
            raise TypeError(f"Type of network should be Cell, but got {type(network)}")
        self._origin_net = network
        # The source of the network's class (not of the instance) is parsed. `dedent` allows classes that were
        # defined inside an indented scope to be parsed as top-level code. Note that `inspect.getsource` may
        # propagate OSError/TypeError when the class source is unavailable.
        network_str = inspect.getsource(type(network))
        self._ast_root: ast.Module = ast.parse(dedent(network_str))
        self._root_tree: Optional[SymbolTree] = None

    @staticmethod
    def ast_transform(ast_root: ast.AST) -> ast.AST:
        """
        Optimize ast before parse.

        The AST is passed through `AstFlattener.transform` together with `SymbolTreeBuilder.entry_functions`,
        presumably so that only those entry functions are flattened (TODO confirm against AstFlattener).

        Args:
             ast_root (ast.AST): An instance of ast to be optimized.

        Returns:
             An instance of ast been optimized.
        """
        return AstFlattener().transform(ast_root, SymbolTreeBuilder.entry_functions)

    def build(self) -> SymbolTree:
        """
        Build SymbolTree.

        Steps:
        - transform the stored AST;
        - wrap it together with the original network in a new `SymbolTree`;
        - run the parser registered for `ast.Module` over it;
        - fill in missing source locations;
        - finalize the tree via `finish_build`.

        Returns:
             An instance of SymbolTree.

        Raises:
            TypeError: If the transformed AST root is not an `ast.Module`.
        """
        # NOTE(review): `self._ast_root` is overwritten with the transformed AST, so calling `build` a second time
        # transforms an already-transformed tree.
        self._ast_root = SymbolTreeBuilder.ast_transform(self._ast_root)
        if not isinstance(self._ast_root, ast.Module):
            raise TypeError(f"Type of ast_root should be ast.Module, but got {type(self._ast_root)}")
        self._root_tree = SymbolTree(self._origin_net, self._ast_root)
        parser: Parser = ParserRegister.instance().get_parser(ast.Module)
        parser.process(self._root_tree, self._ast_root, None)
        # Nodes added or rewritten during parsing may lack lineno/col_offset; fill them in so the module AST can
        # later be compiled.
        ast.fix_missing_locations(self._root_tree.get_module_ast())
        self._root_tree.finish_build()
        return self._root_tree
77