# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree builder."""
from typing import Optional
import ast
import inspect
from textwrap import dedent

from mindspore.nn import Cell
from .symbol_tree import SymbolTree
from ..parsers import Parser, ParserRegister
from ..ast_helpers import AstFlattener


class SymbolTreeBuilder:
    """
    `SymbolTreeBuilder` for building a SymbolTree from network.

    The source code of the network's class is read with `inspect.getsource`, parsed into an `ast.Module`,
    and later turned into a `SymbolTree` by `build`.

    Args:
        network (Cell): An instance of Cell represents a network from which SymbolTree will be built.

    Raises:
        TypeError: If `network` is not an instance of Cell.
    """

    # Entry function of the forward computation process
    entry_functions = ["construct"]

    def __init__(self, network: Cell):
        if not isinstance(network, Cell):
            # Build a single formatted message; passing the object as a second positional
            # argument would make str(exc) render as a tuple.
            raise TypeError(f"Type of network should be Cell, but got {type(network)}")
        self._origin_net = network
        # Source of the network's *class* (not the instance); inspect may raise OSError/TypeError
        # when the class source is unavailable, which is propagated to the caller.
        network_str = inspect.getsource(type(network))
        # dedent so classes defined in a nested scope still parse as top-level code.
        self._ast_root: ast.Module = ast.parse(dedent(network_str))
        self._root_tree: Optional[SymbolTree] = None

    @staticmethod
    def ast_transform(ast_root: ast.AST) -> ast.AST:
        """
        Optimize ast before parse.

        Runs `AstFlattener` over `ast_root`, passing `SymbolTreeBuilder.entry_functions` so that
        (presumably) only the forward-computation entry functions are flattened.

        Args:
            ast_root (ast.AST): An instance of ast to be optimized.

        Returns:
            An instance of ast been optimized.
        """
        return AstFlattener().transform(ast_root, SymbolTreeBuilder.entry_functions)

    def build(self) -> SymbolTree:
        """
        Build SymbolTree.

        Transforms the parsed module ast, wraps it in a `SymbolTree` together with the original network,
        runs the parser registered for `ast.Module` over it, fills in missing ast locations on the
        resulting module ast and finalizes the tree.

        Returns:
            An instance of SymbolTree.

        Raises:
            TypeError: If the transformed ast is not an `ast.Module`.
        """
        # NOTE(review): self._ast_root is replaced by its transformed version here, so calling build()
        # a second time re-transforms and re-parses an already-processed ast. Confirm callers build once.
        self._ast_root = SymbolTreeBuilder.ast_transform(self._ast_root)
        if not isinstance(self._ast_root, ast.Module):
            raise TypeError(f"Type of ast_root should be ast.Module, but got {type(self._ast_root)}")
        self._root_tree = SymbolTree(self._origin_net, self._ast_root)
        parser: Parser = ParserRegister.instance().get_parser(ast.Module)
        parser.process(self._root_tree, self._ast_root, None)
        # Nodes created or rewritten while parsing may lack lineno/col_offset; compile() requires them.
        ast.fix_missing_locations(self._root_tree.get_module_ast())
        self._root_tree.finish_build()
        return self._root_tree