• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Convert ast node to other type."""
16from typing import Union, List
17import ast
18import sys
19
20from mindspore import log as logger
21from ..api.scoped_value import ScopedValue, ValueType
22
23if sys.version_info >= (3, 9):
24    import ast as astunparse # pylint: disable=reimported, ungrouped-imports
25else:
26    import astunparse
27
# Node classes that represent literal constants. Num, Str, NameConstant and Bytes are deprecated
# aliases of ast.Constant since Python 3.8 and were removed in Python 3.14, so they are looked up
# by name and only the ones the running interpreter still provides are kept (same order as before).
AST_CONSTANTS = tuple(
    node_cls
    for node_cls in (getattr(ast, cls_name, None) for cls_name in ("Constant", "Num", "Str", "NameConstant", "Bytes"))
    if node_cls is not None
)
29
30
class AstConverter():
    """
    Get information from ast node and convert to other type.

    Every helper is a static method; the class is only used as a namespace.
    """

    @staticmethod
    def _supported_elem_types() -> tuple:
        """
        Return the ast node classes accepted as tuple elements and dict values.

        The legacy constant classes (ast.Num, ast.Str, ast.Bytes) are deprecated aliases of
        ast.Constant since Python 3.8 and no longer exist on Python 3.14+, so they are looked up
        by name and skipped when missing.
        """
        supported = [ast.Name, ast.Attribute, ast.Tuple, ast.Constant]
        for legacy_name in ("Num", "Str", "Bytes"):
            legacy_cls = getattr(ast, legacy_name, None)
            if legacy_cls is not None:
                supported.append(legacy_cls)
        return tuple(supported)

    @staticmethod
    def get_ast_constant_value(node: ast.AST):
        """
        Get value from ast constant.

        Args:
            node (ast.AST): An ast.Constant node, or one of the legacy constant nodes
                (ast.NameConstant, ast.Num, ast.Str, ast.Bytes) where the interpreter has them.

        Returns:
            The python value stored in the node.

        Raises:
            ValueError: If `node` is not a constant node.
        """
        if isinstance(node, ast.Constant):
            return node.value
        # Legacy node classes keep their value in different attributes. They are resolved by name
        # because they were removed from the ast module in Python 3.14.
        legacy_value_attrs = (("NameConstant", "value"), ("Num", "n"), ("Str", "s"), ("Bytes", "s"))
        for class_name, value_attr in legacy_value_attrs:
            legacy_cls = getattr(ast, class_name, None)
            if legacy_cls is not None and isinstance(node, legacy_cls):
                return getattr(node, value_attr)
        raise ValueError(f"For get_ast_constant_value, node cannot be {type(node)}")

    @staticmethod
    def create_scopedvalue(node: ast.AST) -> ScopedValue:
        """
        Create ScopedValue from an ast node.

        Supported nodes are ast.Name, ast.Attribute whose value is an ast.Name, ast.List,
        ast.Tuple and constant nodes. Any other node is logged and turned into a ScopedValue of
        type ValueType.UnsupportedValue that carries the unparsed source of the node.

        Args:
            node (ast.AST): An ast node.

        Returns:
            An instance of ScopedValue.
        """
        if isinstance(node, ast.Name):
            return ScopedValue.create_naming_value(node.id)
        if isinstance(node, ast.Attribute):
            scope = node.value
            if not isinstance(scope, ast.Name):
                # Only a single-level "scope.attr" can be expressed as a naming value.
                node_str = astunparse.unparse(node).strip()
                logger.info(f"When creating scopedvalue for '{node_str}', value of ast.Attribute should be ast.Name, "
                            f"but got ast type '{type(scope).__name__}'")
                return ScopedValue(ValueType.UnsupportedValue, "", node_str)
            return ScopedValue.create_naming_value(node.attr, scope.id)
        if isinstance(node, (ast.List, ast.Tuple)):
            return AstConverter.create_scopedvalue_from_list(node.elts)
        if isinstance(node, AST_CONSTANTS):
            value = AstConverter.get_ast_constant_value(node)
            return ScopedValue.create_variable_value(value)
        node_str = astunparse.unparse(node).strip()
        logger.info(f"For '{node_str}', type '{type(node).__name__}' is not supported for ScopedValue now.")
        return ScopedValue(ValueType.UnsupportedValue, "", node_str)

    @staticmethod
    def create_scopedvalue_from_list(ast_list: List[ast.AST]) -> ScopedValue:
        """
        Create ScopedValue from a list of ast nodes.

        Constants contribute their value, names their identifier and attributes the string
        "scope.attr". The first element that cannot be converted stops the conversion and an
        unsupported ScopedValue carrying that element's source is returned.

        Args:
            ast_list (List[Union[ast.Constant, ast.Name, ast.Attribute]]): A list of ast nodes.

        Returns:
            An instance of ScopedValue.
        """
        tuple_values = []
        for tuple_elt in ast_list:
            if isinstance(tuple_elt, ast.Constant):
                tuple_values.append(tuple_elt.value)
            elif isinstance(tuple_elt, ast.Name):
                tuple_values.append(tuple_elt.id)
            elif isinstance(tuple_elt, ast.Attribute):
                scope = tuple_elt.value
                if not isinstance(scope, ast.Name):
                    # Nested attributes such as "a.b.c" have no `id` on their value node; report
                    # them as unsupported, as create_scopedvalue does, instead of raising.
                    node_str = astunparse.unparse(tuple_elt).strip()
                    logger.info(f"When creating scopedvalue for '{node_str}', value of ast.Attribute should be "
                                f"ast.Name, but got ast type '{type(scope).__name__}'")
                    return ScopedValue(ValueType.UnsupportedValue, "", node_str)
                tuple_values.append("".join([scope.id, '.', tuple_elt.attr]))
            else:
                node_str = astunparse.unparse(tuple_elt).strip()
                logger.info(f"When create scopedvalue for '{node_str}' only support (ast.Constant, ast.Name, "
                            f"ast.Attribute) as elts of ast.Tuple, but got ast type {type(tuple_elt).__name__}")
                return ScopedValue(ValueType.UnsupportedValue, "", node_str)
        return ScopedValue.create_variable_value(tuple(tuple_values))

    @staticmethod
    def get_ast_name(ast_node: Union[ast.Name, ast.Attribute]) -> str:
        """
        Get name from ast.Name or ast.Attribute.

        Args:
            ast_node (Union[ast.Name, ast.Attribute]): The node to read the name from.

        Returns:
            The identifier of an ast.Name, the attribute name of an ast.Attribute, or an empty
            string for any other node.
        """
        if isinstance(ast_node, ast.Name):
            return ast_node.id
        if isinstance(ast_node, ast.Attribute):
            return ast_node.attr
        return ""

    @staticmethod
    def ast_tuple_elts_support_scopledvalue(value: ast.Tuple) -> bool:
        """
        Check whether each element's type in tuple is supported by scoped value.

        Args:
            value (ast.Tuple): The tuple node whose elements are checked.

        Returns:
            True if every element is a name, attribute, tuple or constant node (also True for an
            empty tuple), otherwise False.
        """
        supported_types = AstConverter._supported_elem_types()
        return all(isinstance(elt, supported_types) for elt in value.elts)

    @staticmethod
    def ast_dict_support_scopledvalue(ast_dict: ast.Dict) -> bool:
        """
        Check whether each element's type in dict is supported by scoped value.

        Args:
            ast_dict (ast.Dict): The dict node whose keys and values are checked.

        Returns:
            True if every key is a string constant and every value is a name, attribute, tuple or
            constant node, otherwise False.
        """
        # Keys must be string literals; anything else (including the None key of "**other")
        # fails the isinstance check.
        for key in ast_dict.keys:
            if not (isinstance(key, ast.Constant) and isinstance(key.value, str)):
                return False
        supported_types = AstConverter._supported_elem_types()
        return all(isinstance(value, supported_types) for value in ast_dict.values)

    @staticmethod
    def get_ast_target_elems(ast_target: ast.AST, convert_to_str: bool = False):
        """
        Get elements in ast.

        Tuples and lists are flattened recursively; every other node is a single element.

        Args:
            ast_target (ast.AST): The node to flatten, e.g. the target of an assignment.
            convert_to_str (bool): If True, each element is returned as its unparsed source
                string instead of the ast node. Default: False.

        Returns:
            A flat list of ast nodes, or of strings when `convert_to_str` is True.
        """
        if isinstance(ast_target, (ast.Tuple, ast.List)):
            target_ast_elems = []
            for ast_elem in ast_target.elts:
                # Forward convert_to_str so nested elements are converted as well.
                target_ast_elems.extend(AstConverter.get_ast_target_elems(ast_elem, convert_to_str))
            return target_ast_elems
        if convert_to_str:
            return [astunparse.unparse(ast_target).strip()]
        return [ast_target]
144