# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Replacing specific symbol name with another symbol name in specific scope."""

from typing import Any
import ast


class AstReplacer(ast.NodeTransformer):
    """
    Replace all specific symbol name in specific scope with another symbol name.

    Only the ``name`` of ``ast.ClassDef`` nodes and the ``id`` of ``ast.Name``
    nodes are rewritten. Every rewrite is recorded so that it can be reverted
    later with :meth:`undo_all`.

    Args:
        node (ast.AST): An instance of ast node as replace scope.
    """

    def __init__(self, node: ast.AST):
        super().__init__()
        self._scope = node
        self._src = ""
        self._dst = ""
        # Chronological log of rewrites: (node, attribute name, old value, new value).
        self._trace = []

    def visit_ClassDef(self, node: ast.ClassDef) -> Any:
        """
        An override method, call back when visiting an ast.ClassDef node.

        Renames the class definition when its name matches the source symbol,
        then continues into the child nodes.

        Args:
            node (ast.ClassDef): An instance of ast.ClassDef which is visited currently.

        Returns:
            The result of ``generic_visit`` on the (possibly renamed) node.
        """

        if node.name == self._src:
            node.name = self._dst
            self._trace.append((node, "name", self._src, self._dst))
        return self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        """
        An override method, call back when visiting an ast.Name node.

        Rewrites the identifier when it matches the source symbol.

        Args:
            node (ast.Name): An instance of ast.Name which is visited currently.

        Returns:
            The result of ``generic_visit`` on the (possibly renamed) node.
        """

        if node.id == self._src:
            node.id = self._dst
            self._trace.append((node, "id", self._src, self._dst))
        return self.generic_visit(node)

    def replace_all(self, src: str, dst: str) -> None:
        """
        Replace all matched symbol to new symbol name.

        May be called several times; every applied rewrite is appended to the
        same trace and reverted together by :meth:`undo_all`.

        Args:
            src (str): Target symbol name to be replaced out.
            dst (str): New symbol name to be replaced in.
        """

        self._src = src
        self._dst = dst
        self.visit(self._scope)

    def undo_all(self) -> None:
        """Undo all replace-actions applied on current scope."""

        # Revert newest-first: when the same node was renamed more than once
        # (e.g. a -> b, then b -> c), the oldest recorded value must be the one
        # written last so the node ends up with its original name.
        for node, attr, old_value, _ in reversed(self._trace):
            setattr(node, attr, old_value)
        self._trace.clear()