• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Ast optimizer for flatten recursive call."""
16
from typing import Any, Tuple, List, Dict, Union
import keyword
import ast
import copy
import sys

from mindspore import log as logger
23
# Call names that must never be flattened. The name is matched against the called function's bare name
# (`func(...)`) or against the last attribute of a method call (`obj.func(...)`).
FLATTEN_BLACK_LIST = [
    "set_vertex_attr",
]
25
26
class AstFlattener(ast.NodeTransformer):
    """
    Ast optimizer for flatten recursive call.

    Nested sub-expressions are hoisted into new assignments that are inserted, in evaluation order, in front of
    the statement that used them. For example ``return f(g(x), h(y))`` becomes::

        g_var = g(x)
        h_var = h(y)
        f_var = f(g_var, h_var)
        return f_var

    Only the statements directly inside the selected function bodies (or control-flow bodies) are processed.
    Nested bodies are not visited.
    """

    # Record origin test ast, used to judge direction of static if control flow.
    # NOTE: class attribute, so the cache is shared by every AstFlattener instance; nothing in this class clears it.
    ast_if_test_cache: Dict[ast.If, ast.AST] = {}

    # Literal expressions never need to be flattened. Since Python 3.8 the parser emits ast.Constant for every
    # literal. The legacy aliases (ast.Num, ast.Str, ast.Bytes, ast.NameConstant, ast.Ellipsis) are deprecated and
    # removed in Python 3.14, so they are only referenced on older interpreters.
    _LITERAL_TYPES: Tuple[type, ...] = (
        (ast.Constant,) if sys.version_info >= (3, 8)
        else (ast.Constant, ast.NameConstant, ast.Num, ast.Str, ast.Bytes, ast.Ellipsis)
    )

    def __init__(self):
        """
        Constructor of AstFlattener.

        Returns:
            An instance of ast optimizer for flatten recursive call.
        """
        # Maps a node type to the fields whose sub-expressions may be hoisted into new assignments.
        self._flatten_table: dict = {
            ast.Return: ["value"],
            ast.Call: ["func", "args", "keywords"],
            ast.BinOp: ["left", "right"],
            ast.BoolOp: ["values"],
            ast.UnaryOp: ["operand"],
            ast.Compare: ["left", "comparators"],
            ast.If: ["test"],
            ast.For: ["iter"],
            ast.Tuple: ["elts"],
            ast.List: ["elts"],
        }
        # Names of the functions whose bodies `transform` flattens.
        self._transform_functions = []
        self._symbol_tree = None  # Used to get unique name

    @staticmethod
    def _check_flatten_black_list(node: ast.AST):
        """Return True if `node` is a call whose function (or method) name is in FLATTEN_BLACK_LIST."""
        if not isinstance(node, ast.Call):
            return False
        func = node.func
        if isinstance(func, ast.Name):
            func_name = func.id
        elif isinstance(func, ast.Attribute):
            func_name = func.attr
        else:
            return False
        return func_name in FLATTEN_BLACK_LIST

    @staticmethod
    def _as_load(node: ast.AST) -> ast.AST:
        """Return a deep copy of `node` in which every expression context is ast.Load."""
        node = copy.deepcopy(node)
        for sub_node in ast.walk(node):
            if hasattr(sub_node, "ctx"):
                sub_node.ctx = ast.Load()
        return node

    @staticmethod
    def _flatten_continuous_assign(ast_body: List[ast.AST]):
        """
        Flatten ast.Assign with continuous targets, in place.

        ``a = b = c = value`` is rewritten into ``c = value``, ``a = c``, ``b = c``. After the rewrite every
        assignment has a single target, and every name is still bound to the assigned value.
        """
        pos = 0
        while pos < len(ast_body):
            ast_node = ast_body[pos]
            pos += 1
            if not isinstance(ast_node, ast.Assign) or len(ast_node.targets) <= 1:
                continue
            *leading_targets, last_target = ast_node.targets
            ast_node.targets = [last_target]
            # Each leading target reads the value back from the last target, which is assigned first. The target
            # is copied with a Load context: an expression with a Store context cannot be compiled as a value.
            new_assigns = [
                ast.copy_location(ast.Assign(targets=[target], value=AstFlattener._as_load(last_target)), ast_node)
                for target in leading_targets
            ]
            ast_body[pos:pos] = new_assigns
            # Skip the single-target statements just inserted.
            pos += len(new_assigns)

    @staticmethod
    def _save_target_names(ast_body: List[ast.AST]):
        """Saving target names in ast_body before getting unique names."""
        target_names = []
        for child in ast_body:
            if not isinstance(child, ast.Assign):
                continue
            # Walk (possibly nested) tuple/list/starred targets with an explicit stack.
            pending = list(child.targets)
            while pending:
                target = pending.pop()
                if isinstance(target, ast.Name):
                    if target.id not in target_names:
                        target_names.append(target.id)
                elif isinstance(target, (ast.Tuple, ast.List)):
                    pending.extend(target.elts)
                elif isinstance(target, ast.Starred):
                    pending.append(target.value)
        return target_names

    def _generate_target_name(self, node: ast.AST, target_names):
        """
        Generate unique target name.

        The base name is derived from `node`. A numeric suffix is added while the name is already in
        `target_names`. If a symbol tree is set, the result is also passed through its `unique_name`. The final
        name is appended to `target_names` and returned.
        """
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                target_name = func.id + "_var"
            elif isinstance(func, ast.Attribute):
                target_name = func.attr + "_var"
            else:
                logger.debug("unhandled type of func of ast.Call while generating new target name: %s ", type(func))
                target_name = "function_var"
        elif isinstance(node, ast.Return):
            target_name = "return_value"
        elif isinstance(node, (ast.BinOp, ast.BoolOp, ast.UnaryOp)):
            target_name = type(node.op).__name__.lower() + "_var"
        elif isinstance(node, (ast.Tuple, ast.List)):
            target_name = type(node).__name__.lower() + "_var"
        elif isinstance(node, ast.Name):
            target_name = node.id
        elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
            target_name = f"{node.value.id}_{node.attr}"
        else:
            logger.debug("unhandled type of node while generating new target name: %s ", type(node))
            target_name = type(node).__name__.lower() + "_var"
        # avoid python built-in keyword
        if keyword.iskeyword(target_name):
            target_name = target_name + "_var"
        suffix = 0
        result = target_name
        while result in target_names:
            suffix += 1
            result = f"{target_name}_{suffix}"
        if self._symbol_tree:
            result = self._symbol_tree.unique_name(result)
        target_names.append(result)
        return result

    def _create_new_assign_node(self, node: ast.AST, target_names, father_node: ast.AST) \
            -> Tuple[ast.AST, Union[ast.Assign, None]]:
        """
        Create new assign node to be inserted into ast.FunctionDef.

        Args:
            node: Sub-expression that may need to be flattened.
            target_names: Names already in use. A generated name is appended to it.
            father_node: Node that owns `node`.

        Returns:
            ``(node, None)`` when `node` stays in place. Otherwise ``(replacement, assign)``, where `assign`
            computes the hoisted value and `replacement` is a new node that refers to it.
        """
        if isinstance(node, (ast.Name,) + AstFlattener._LITERAL_TYPES):
            return node, None
        # ast.Attribute in ast.For will be force flatten.
        # Outside ast.For, only the innermost value of an attribute chain is flattened, and only when that value
        # is not an ast.Name, e.g. `f().a.b` => `f_var = f()` & `f_var.a.b`.
        if isinstance(node, ast.Attribute) and not isinstance(father_node, ast.For):
            chain = [node]  # outermost attribute first
            while isinstance(chain[-1].value, ast.Attribute):
                chain.append(chain[-1].value)
            root_value = chain[-1].value
            if isinstance(root_value, ast.Name):
                return node, None
            new_target_name = self._generate_target_name(root_value, target_names)
            # Rebuild the whole chain on top of the new name so that no attribute access is lost.
            new_node = ast.Name(id=new_target_name, ctx=ast.Load())
            for attr_node in reversed(chain):
                new_node = ast.Attribute(value=new_node, attr=attr_node.attr, ctx=attr_node.ctx)
            return new_node, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=root_value)
        # flatten nodes
        new_target_name = self._generate_target_name(node, target_names)
        new_node = ast.Name(id=new_target_name, ctx=ast.Load())
        return new_node, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)

    def _flatten_statement(self, node: ast.AST, target_names) -> List[ast.AST]:
        """
        Flatten recursive statement according to different node type.

        The fields of `node` listed in `self._flatten_table` are processed. Each sub-expression that needs
        flattening is replaced in place by a reference to a new variable.

        Args:
            node: Expression or statement whose direct sub-expressions are flattened.
            target_names: Names already in use. Generated names are appended to it.

        Returns:
            The new ast.Assign nodes in evaluation order, to be inserted before the statement. Possibly empty.
        """
        if AstFlattener._check_flatten_black_list(node):
            return []
        flatten_config = self._flatten_table.get(type(node))
        if flatten_config is None:
            return []
        results = []
        for todo_name in flatten_config:
            todos = getattr(node, todo_name)
            if todos is None:
                # Absent optional field, e.g. the value of a bare `return`.
                continue
            if isinstance(todos, list):
                new_list = []
                for idx, todo in enumerate(todos):
                    # Starred expression(e.g. *args) cannot be flatten.
                    if isinstance(todo, ast.Starred):
                        new_list.append(todo)
                        continue
                    # Operands after the first one of 'xxx and yyy', 'xxx or yyy' and 'xxx < yyy < zzz' are only
                    # evaluated depending on the previous result (short-circuit). Hoisting them could execute code
                    # Python would skip, e.g. 'x is None or x.foo()' would raise when x is None.
                    if idx > 0 and isinstance(node, (ast.BoolOp, ast.Compare)):
                        new_list.append(todo)
                        continue
                    # ast.keywords are processed individually:
                    # y = func(key=value) => new_target_name = value & y = func(key=new_target_name)
                    if isinstance(todo, ast.keyword):
                        new_value, new_assign = self._create_new_assign_node(todo.value, target_names, node)
                        if new_value is not todo.value:
                            todo.value = new_value
                            results.append(new_assign)
                        new_list.append(todo)
                        continue
                    new_node, new_assign = self._create_new_assign_node(todo, target_names, node)
                    if new_node is not todo:
                        results.append(new_assign)
                    new_list.append(new_node)
                setattr(node, todo_name, new_list)
            elif isinstance(todos, dict):
                new_dict = {}
                for key, value in todos.items():
                    new_node, new_assign = self._create_new_assign_node(value, target_names, node)
                    if new_node is not value:
                        results.append(new_assign)
                    new_dict[key] = new_node
                setattr(node, todo_name, new_dict)
            else:
                new_node, new_assign = self._create_new_assign_node(todos, target_names, node)
                if new_node is not todos:
                    setattr(node, todo_name, new_node)
                    results.append(new_assign)
        return results

    def _visit_ast_bodies(self, ast_body: List[ast.AST]):
        """
        Traverse nodes in ast_body and flatten nodes recursive.

        Statements are visited from last to first. The assignments hoisted out of a statement are inserted right
        before it, in evaluation order. Those assignments are then visited too, so nesting is removed level by level.
        """
        # Flatten continuous assign statements in ast_body
        AstFlattener._flatten_continuous_assign(ast_body)
        # save target names, used when create new assign ast node
        target_names = AstFlattener._save_target_names(ast_body)
        index = len(ast_body) - 1
        while index >= 0:
            child = ast_body[index]
            # Record origin test ast, used to judge direction of static if control flow.
            if isinstance(child, ast.If) and child not in AstFlattener.ast_if_test_cache:
                AstFlattener.ast_if_test_cache[child] = copy.deepcopy(child.test)

            stmt = child.value if isinstance(child, (ast.Assign, ast.Expr)) else child
            results = self._flatten_statement(stmt, target_names)
            # Insert the hoisted assignments, in order, right before `child`. Then step to the last inserted one,
            # or to the previous statement when nothing was hoisted.
            ast_body[index:index] = results
            index += len(results) - 1

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:  # pylint: disable=invalid-name
        """
        Flatten the body of functions whose name is in _transform_functions. Other functions are returned unchanged.

        generic_visit is not called, so function definitions nested inside the body are not visited.
        """
        if node.name not in self._transform_functions:
            return node
        self._visit_ast_bodies(node.body)
        return node

    def transform(self, ast_root, transform_functions=None, stree=None):
        """
        Interface of AstFlattener.

        Args:
            ast_root: AST to transform.
            transform_functions: Names of the functions to flatten. Defaults to ``["construct"]``.
            stree: Optional symbol tree whose ``unique_name`` is used when generating variable names.

        Returns:
            The transformed AST, with missing source locations filled in.
        """
        self._transform_functions = transform_functions if transform_functions else ["construct"]
        self._symbol_tree = stree
        ast_root = self.visit(ast_root)
        ast_root = ast.fix_missing_locations(ast_root)
        return ast_root

    def transform_control_flow(self, ast_control_flow: Union[ast.If, ast.For, ast.While], stree=None):
        """
        Interface of AstFlattener for control flow nodes.

        Flattens the statements of `ast_control_flow.body` and, if present, `ast_control_flow.orelse`.

        Returns:
            The same control-flow node, with missing source locations filled in.
        """
        self._transform_functions = []
        self._symbol_tree = stree
        self._visit_ast_bodies(ast_control_flow.body)
        if ast_control_flow.orelse:
            self._visit_ast_bodies(ast_control_flow.orelse)
        ast_control_flow = ast.fix_missing_locations(ast_control_flow)
        return ast_control_flow
269