• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Replacing specific symbol name with another symbol name in specific scope."""
16
17from typing import Any
18import ast
19
20
class AstReplacer(ast.NodeTransformer):
    """
    Replace all occurrences of a specific symbol name with another symbol name inside a specific scope.

    Only two kinds of fields are rewritten: ``ast.ClassDef.name`` and ``ast.Name.id``. Every rewrite is recorded,
    so all replacements applied through this instance can later be reverted with ``undo_all``.

    Args:
        node (ast.AST): An instance of ast node as replace scope.
    """

    def __init__(self, node: ast.AST):
        self._scope = node
        # Symbol currently being searched for / written in; set by ``replace_all`` before each traversal.
        self._src = ""
        self._dst = ""
        # Rename history in application order. Each entry is
        # (node, attribute name, value before rename, value after rename).
        self._trace = []

    def visit_ClassDef(self, node: ast.ClassDef) -> Any:
        """
        An override method, call back when visiting an ast.ClassDef node.

        Renames the class if its name equals the current source symbol, then keeps visiting its children so that
        nested occurrences are replaced as well.

        Args:
            node (ast.ClassDef): An instance of ast.ClassDef which is visited currently.

        Returns:
            The result of ``generic_visit`` on ``node``.
        """

        if node.name == self._src:
            node.name = self._dst
            self._trace.append((node, "name", self._src, self._dst))
        return self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        """
        An override method, call back when visiting an ast.Name node.

        Renames the identifier if it equals the current source symbol.

        Args:
            node (ast.Name): An instance of ast.Name which is visited currently.

        Returns:
            The result of ``generic_visit`` on ``node``.
        """

        if node.id == self._src:
            node.id = self._dst
            self._trace.append((node, "id", self._src, self._dst))
        return self.generic_visit(node)

    def replace_all(self, src: str, dst: str):
        """
        Replace all matched symbol to new symbol name.

        The scope node is modified in place. Each rename is recorded for ``undo_all``.

        Args:
            src (str): Target symbol name to be replaced out.
            dst (str): New symbol name to be replaced in.
        """

        self._src = src
        self._dst = dst
        self.visit(self._scope)

    def undo_all(self):
        """
        Undo all replace-actions applied on current scope and clear the recorded history.

        Renames are reverted newest-first. One node may have been renamed by several ``replace_all`` calls
        (e.g. ``a -> b`` and then ``b -> c``). Restoring oldest-first would leave that node at an intermediate
        name instead of its original one.
        """

        for node, attr, old_value, _ in reversed(self._trace):
            setattr(node, attr, old_value)
        self._trace.clear()
80