# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ast node to other type."""
from typing import Union, List
import ast
import sys

from mindspore import log as logger
from ..api.scoped_value import ScopedValue, ValueType

if sys.version_info >= (3, 9):
    import ast as astunparse  # pylint: disable=reimported, ungrouped-imports
else:
    import astunparse

# ast.Num / ast.Str / ast.NameConstant / ast.Bytes are deprecated aliases of ast.Constant since
# Python 3.8 (every literal is parsed as ast.Constant from then on) and no longer exist in
# Python 3.14, so they are only referenced on interpreters whose parser can still emit them.
if sys.version_info >= (3, 8):
    AST_CONSTANTS = (ast.Constant,)
else:
    AST_CONSTANTS = (ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes)

# Node types accepted as elements of a tuple, or as values of a dict, by the support checks below.
_SUPPORTED_ELEMENT_TYPES = (ast.Name, ast.Attribute, ast.Tuple) + AST_CONSTANTS


class AstConverter():
    """
    Get information from ast node and convert to other type.
    """

    @staticmethod
    def get_ast_constant_value(node: ast.AST):
        """
        Get the python value held by an ast constant node.

        Args:
            node (ast.AST): A constant node, i.e. an instance of one of the types in AST_CONSTANTS.

        Returns:
            The python value stored in the node.

        Raises:
            ValueError: If `node` is not a constant node.
        """
        if isinstance(node, ast.Constant):
            return node.value
        if sys.version_info < (3, 8):
            # Legacy node classes store their value under different attribute names.
            if isinstance(node, ast.NameConstant):
                return node.value
            if isinstance(node, ast.Num):
                return node.n
            if isinstance(node, (ast.Str, ast.Bytes)):
                return node.s
        raise ValueError(f"For get_ast_constant_value, node cannot be {type(node)}")

    @staticmethod
    def create_scopedvalue(node: ast.AST) -> ScopedValue:
        """
        Create ScopedValue from an ast node.

        ast.Name and ast.Attribute (whose value is an ast.Name) become naming values, ast.List and
        ast.Tuple are delegated to create_scopedvalue_from_list, constants become variable values.
        Any other node is logged and wrapped into a ScopedValue of type UnsupportedValue that
        carries the unparsed source of the node.

        Args:
            node (ast.AST): An ast node.

        Returns:
            An instance of ScopedValue.
        """
        if isinstance(node, ast.Name):
            return ScopedValue.create_naming_value(node.id)
        if isinstance(node, ast.Attribute):
            scope = node.value
            if not isinstance(scope, ast.Name):
                node_str = astunparse.unparse(node).strip()
                logger.info(f"When creating scopedvalue for '{node_str}', value of ast.Attribute should be ast.Name, "
                            f"but got ast type '{type(scope).__name__}'")
                return ScopedValue(ValueType.UnsupportedValue, "", node_str)
            return ScopedValue.create_naming_value(node.attr, scope.id)
        if isinstance(node, (ast.List, ast.Tuple)):
            return AstConverter.create_scopedvalue_from_list(node.elts)
        if isinstance(node, AST_CONSTANTS):
            value = AstConverter.get_ast_constant_value(node)
            return ScopedValue.create_variable_value(value)
        node_str = astunparse.unparse(node).strip()
        logger.info(f"For '{node_str}', type '{type(node).__name__}' is not supported for ScopedValue now.")
        return ScopedValue(ValueType.UnsupportedValue, "", node_str)

    @staticmethod
    def create_scopedvalue_from_list(ast_list: List[ast.AST]) -> ScopedValue:
        """
        Create ScopedValue from a list of ast nodes.

        Constants contribute their python value, ast.Name contributes its identifier and
        ast.Attribute contributes the string "<name>.<attr>". The collected values are packed into
        a tuple. The first element that cannot be converted stops the conversion and a ScopedValue
        of type UnsupportedValue carrying the unparsed source of that element is returned.

        Args:
            ast_list (List[Union[ast.Constant, ast.Name, ast.Attribute]]): A list of ast nodes.

        Returns:
            An instance of ScopedValue.
        """
        tuple_values = []
        for tuple_elt in ast_list:
            if isinstance(tuple_elt, AST_CONSTANTS):
                tuple_values.append(AstConverter.get_ast_constant_value(tuple_elt))
            elif isinstance(tuple_elt, ast.Name):
                tuple_values.append(tuple_elt.id)
            elif isinstance(tuple_elt, ast.Attribute):
                scope = tuple_elt.value
                if not isinstance(scope, ast.Name):
                    # Chained attributes such as 'a.b.c' have no single scope name to join with.
                    node_str = astunparse.unparse(tuple_elt).strip()
                    logger.info(f"When creating scopedvalue for '{node_str}', value of ast.Attribute should be "
                                f"ast.Name, but got ast type '{type(scope).__name__}'")
                    return ScopedValue(ValueType.UnsupportedValue, "", node_str)
                tuple_values.append("".join([scope.id, '.', tuple_elt.attr]))
            else:
                node_str = astunparse.unparse(tuple_elt).strip()
                logger.info(f"When create scopedvalue for '{node_str}' only support (ast.Constant, ast.Name, "
                            f"ast.Attribute) as elts of ast.Tuple, but got ast type {type(tuple_elt).__name__}")
                return ScopedValue(ValueType.UnsupportedValue, "", node_str)
        return ScopedValue.create_variable_value(tuple(tuple_values))

    @staticmethod
    def get_ast_name(ast_node: Union[ast.Name, ast.Attribute]) -> str:
        """
        Get name from ast.Name or ast.Attribute.

        Args:
            ast_node (Union[ast.Name, ast.Attribute]): An ast node.

        Returns:
            The identifier of an ast.Name, the attribute name of an ast.Attribute, or an empty
            string for any other node.
        """
        if isinstance(ast_node, ast.Name):
            return ast_node.id
        if isinstance(ast_node, ast.Attribute):
            return ast_node.attr
        return ""

    @staticmethod
    def ast_tuple_elts_support_scopledvalue(value: ast.Tuple) -> bool:
        """
        Check whether the type of each element in the tuple is supported by ScopedValue.

        Args:
            value (ast.Tuple): The tuple node whose elements are checked.

        Returns:
            True if every element is an ast.Name, ast.Attribute, ast.Tuple or a constant node
            (also True for an empty tuple), otherwise False.
        """
        return all(isinstance(elt, _SUPPORTED_ELEMENT_TYPES) for elt in value.elts)

    @staticmethod
    def ast_dict_support_scopledvalue(ast_dict: ast.Dict) -> bool:
        """
        Check whether the type of each key and value in the dict is supported by ScopedValue.

        Args:
            ast_dict (ast.Dict): The dict node to check.

        Returns:
            True if every key is a string constant and every value is an ast.Name, ast.Attribute,
            ast.Tuple or a constant node, otherwise False.
        """
        for key in ast_dict.keys:
            # Any key that is not a constant node holding a str is rejected.
            if not (isinstance(key, AST_CONSTANTS) and isinstance(AstConverter.get_ast_constant_value(key), str)):
                return False
        return all(isinstance(value, _SUPPORTED_ELEMENT_TYPES) for value in ast_dict.values)

    @staticmethod
    def get_ast_target_elems(ast_target: ast.AST, convert_to_str: bool = False) -> List[Union[ast.AST, str]]:
        """
        Get the flattened elements of an ast target.

        Args:
            ast_target (ast.AST): The target node. ast.Tuple and ast.List nodes are expanded
                recursively, any other node is treated as a single element.
            convert_to_str (bool): If True, each element is returned as its unparsed source
                string instead of the ast node. Default: False.

        Returns:
            A flat list of ast nodes, or of strings when `convert_to_str` is True.
        """
        target_ast_elems = []
        if isinstance(ast_target, (ast.Tuple, ast.List)):
            for ast_elem in ast_target.elts:
                # Forward convert_to_str so that nested elements are converted as well.
                target_ast_elems.extend(AstConverter.get_ast_target_elems(ast_elem, convert_to_str))
        elif convert_to_str:
            target_ast_elems.append(astunparse.unparse(ast_target).strip())
        else:
            target_ast_elems.append(ast_target)
        return target_ast_elems