• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Code Trace Analyzer utils."""
16import os
17import time
18import types
19import re
20import inspect
21from mindspore import nn
22
23
class CodeTraceAnalyzer:
    """
    Code Trace Analyzer.

    Checks how many source lines of a Cell's ``construct`` (recursively over its sub cells)
    or of a plain function appear in a saved IR file, i.e. how many lines were traced.

    Args:
        obj (Cell, Function): The cell or function to analyze.
        save_graphs_path (str): The directory containing the saved ir files.
        ir_file (str): The ir file stem to read. The most recently created file named
            ``<digits>_<ir_file>_<digits>.ir`` in ``save_graphs_path`` is used. Default: execute.
    """

    def __init__(self, obj, save_graphs_path, ir_file="execute"):
        self.obj = obj
        self.save_graphs_path = save_graphs_path
        self.ir_content = ""
        # Holds the ir file stem until _read_ir_content() replaces it with the full path read.
        self.ir_file_name = ir_file
        self.code_lines = 0
        self.ignore_code_lines = 0
        self.traced_code_lines = 0
        self.not_traced_codes = []
        self.accuracy = 0.0
        self.analyzed = False
        # Functions whose source path contains one of these fragments are not checked.
        self.dir_black_lists = ["mindspore/nn", "mindspore/ops", "mindspore/train"]
        self.cost_time = 0
        self.extra_fns = []
        # Code objects already checked, so a construct shared by several cell instances
        # (or a function passed twice) is only counted once.
        self._checked_codes = set()

    @staticmethod
    def skip_check(line):
        """
        Return True if the source line should not be counted as a code line.

        Blank lines, comments, decorators and ``def`` header lines are skipped.
        """
        line_strip = line.strip()
        if not line_strip:  # blank line
            return True
        # comment, decorator or function define
        return line_strip.startswith(('#', '@', 'def '))

    def add_functions(self, *functions):
        """
        Add more functions those are not top functions to analyze.

        Raises:
            ValueError: If any argument is not a plain function or method.
        """
        for fn in functions:
            if not isinstance(fn, (types.FunctionType, types.MethodType)):
                raise ValueError(f"{fn} must be a Function")
            self.extra_fns.append(fn)

    def analyze(self):
        """
        Start to analyze the code trace accuracy and return the accuracy.

        Returns:
            float, the number of traced lines divided by the number of counted,
            non-ignored lines.

        Raises:
            ValueError: If called more than once, if ``obj`` is neither a Cell nor a
                function, if no matching ir file exists, or if there is no code line to check.
        """
        if self.analyzed:
            raise ValueError("analyze() can only call once.")

        start_time = time.time()
        self._read_ir_content()
        if isinstance(self.obj, nn.Cell):
            self._check_net(self.obj)
        elif isinstance(self.obj, (types.FunctionType, types.MethodType)):
            self._check_function(self.obj)
        else:
            raise ValueError(f"Obj {self.obj} must be a Cell or Function.")

        for fn in self.extra_fns:
            self._check_function(fn)

        self.analyzed = True
        checked_lines = self.code_lines - self.ignore_code_lines
        if checked_lines <= 0:
            # Previously a bare ZeroDivisionError; nothing was checked, e.g. every function
            # was black-listed or all lines were marked <IGNORE>.
            raise ValueError(
                f"No code lines to analyze: {self.code_lines} code lines found, "
                f"{self.ignore_code_lines} of them ignored.")
        self.accuracy = self.traced_code_lines / checked_lines
        self.cost_time = time.time() - start_time
        return self.accuracy

    def report_analysis(self):
        """Print the analysis result; requires a successful analyze() call first."""
        if not self.analyzed:
            print("Please run analyze() success first.")
            return

        print("\n------Code Trace Analysis-------")
        print(f"The code trace accuracy is {self.accuracy}")
        print(
            f"All of code lines is {self.code_lines}, ignored code lines is {self.ignore_code_lines}. "
            f"And there are {self.traced_code_lines} of codes appeared in the ir file: {self.ir_file_name}")
        print(f"#analyze() cost time: {self.cost_time}s")
        if self.not_traced_codes:
            print("Below codes are not traced in ir file:")
            for index, line in enumerate(self.not_traced_codes):
                print(f"[{index}] {line}")

    def _read_ir_content(self):
        """
        Read the most recently created ir file named ``<digits>_<ir_file>_<digits>.ir``.

        Stores the file content in ``self.ir_content`` and replaces ``self.ir_file_name``
        with the full path of the file that was read.

        Raises:
            ValueError: If no matching ir file exists in ``save_graphs_path``.
        """
        # Escape the stem so characters like '.' or '+' in it match literally, and require a
        # full-name match so files such as '1_execute_0.ir.bak' are not selected.
        pattern = re.compile(rf'\d+_{re.escape(self.ir_file_name)}_\d+\.ir')
        ir_files = [os.path.join(self.save_graphs_path, f)
                    for f in os.listdir(self.save_graphs_path) if pattern.fullmatch(f)]
        if not ir_files:
            raise ValueError(
                f"No ir file matching '{pattern.pattern}' found in {self.save_graphs_path}.")
        file_name = max(ir_files, key=os.path.getctime)
        with open(file_name, 'r') as f:
            self.ir_content = f.read()
        self.ir_file_name = file_name

    def _check_lines(self, fn):
        """
        Count the code lines of ``fn`` and check whether each one is traced in the ir content.

        A line counts as traced when the ir content contains
        ``In file <source file>:<line number>.../<source line>/``. Lines containing
        ``<IGNORE>`` are counted but excluded from the accuracy. Functions defined in a
        black-listed directory are skipped entirely.
        """
        fn_file_name: str = fn.__code__.co_filename
        # Normalize separators so the '/'-based black list also matches Windows paths.
        normalized_file_name = fn_file_name.replace('\\', '/')
        if any(item in normalized_file_name for item in self.dir_black_lists):
            return

        lines, offset = inspect.getsourcelines(fn)
        for index, raw_line in enumerate(lines):
            code = raw_line.rstrip('\r\n')
            if self.skip_check(code):
                continue

            self.code_lines += 1

            if "<IGNORE>" in code:
                self.ignore_code_lines += 1
                continue

            line_no = offset + index
            # Escape the file name and the source text so regex metacharacters in them
            # ('|', '?', '\\', '$', Windows path separators, ...) are matched literally.
            # (?!\d) stops the entry for e.g. line 120 from matching line 12.
            pattern = (re.escape(f"In file {fn_file_name}:{line_no}") + r"(?!\d).*/"
                       + re.escape(code) + "/")
            if re.search(pattern, self.ir_content):
                self.traced_code_lines += 1
            else:
                # Keep a readable description rather than the escaped regex.
                self.not_traced_codes.append(f"In file {fn_file_name}:{line_no}/{code}/")

    def _check_net(self, cell):
        """Recursively check the cell and its sub cell except the mindspore inner cells."""
        self._check_function(cell.construct)

        for sub_cell in cell.cells():
            self._check_net(sub_cell)

    def _check_function(self, fn):
        """Check the given function once, after unwrapping any decorators."""
        fn = inspect.unwrap(fn)
        code = fn.__code__
        if code in self._checked_codes:
            return
        self._checked_codes.add(code)
        self._check_lines(fn)