# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Code Trace Analyzer utils."""
import os
import time
import types
import re
import inspect
from mindspore import nn


class CodeTraceAnalyzer:
    """
    Code Trace Analyzer.

    Measures how many source lines of the analyzed code appear as trace info
    ("In file <path>:<line>.../<code>/") in a saved IR file. For a Cell, the `construct`
    of the cell and of all its sub cells are checked; for a function, only that function.

    Args:
        obj (Cell, Function): The obj to analyze.
        save_graphs_path (str): The directory of the saved ir files.
        ir_file (str): The name of the ir files to be read, i.e. the `NAME` part of
            `<digits>_NAME_<digits>.ir`. Default: "execute".
    """

    def __init__(self, obj, save_graphs_path, ir_file="execute"):
        self.obj = obj
        self.save_graphs_path = save_graphs_path
        self.ir_content = ""
        # Bare IR name until _read_ir_content() replaces it with the full path of the file it read.
        self.ir_file_name = ir_file
        self.code_lines = 0  # counted source lines, including <IGNORE> lines
        self.ignore_code_lines = 0  # counted lines containing the <IGNORE> marker
        self.traced_code_lines = 0  # counted lines found in the IR content
        self.not_traced_codes = []  # readable descriptions of lines missing from the IR content
        self.accuracy = 0.0
        self.analyzed = False
        # Functions whose source file path contains one of these fragments are not checked.
        self.dir_black_lists = ["mindspore/nn", "mindspore/ops", "mindspore/train"]
        self.cost_time = 0
        self.extra_fns = []

    @staticmethod
    def skip_check(line):
        """Return True if the line is not counted: blank, comment, decorator or `def ` header."""
        line_strip = line.strip()
        return not line_strip or line_strip.startswith(("#", "@", "def "))

    def add_functions(self, *functions):
        """
        Add more functions those are not top functions to analyze.

        Raises:
            ValueError: If any argument is not a plain function or bound method.
        """
        for fn in functions:
            if not isinstance(fn, (types.FunctionType, types.MethodType)):
                raise ValueError(f"{fn} must be a Function")
            self.extra_fns.append(fn)

    def analyze(self):
        """
        Start to analyze the code trace accuracy and return the accuracy.

        Returns:
            float, traced lines / (counted lines - ignored lines). If there is no line left to
            check (e.g. all lines are <IGNORE> or black-listed), the accuracy is 1.0.

        Raises:
            ValueError: If called more than once, if no matching ir file is found, or if
                `obj` is neither a Cell nor a Function.
        """
        if self.analyzed:
            raise ValueError("analyze() can only call once.")

        start_time = time.perf_counter()
        self._read_ir_content()
        if isinstance(self.obj, nn.Cell):
            self._check_net(self.obj)
        elif isinstance(self.obj, (types.FunctionType, types.MethodType)):
            self._check_function(self.obj)
        else:
            raise ValueError(f"Obj {self.obj} must be a Cell or Function.")

        for fn in self.extra_fns:
            self._check_function(fn)

        self.analyzed = True
        checked_lines = self.code_lines - self.ignore_code_lines
        # Guard against ZeroDivisionError: with nothing to check, no line failed to be traced.
        self.accuracy = self.traced_code_lines / checked_lines if checked_lines else 1.0
        self.cost_time = time.perf_counter() - start_time
        return self.accuracy

    def report_analysis(self):
        """Print the analysis result; analyze() must have completed successfully first."""
        if not self.analyzed:
            print("Please run analyze() success first.")
            return

        print("\n------Code Trace Analysis-------")
        print(f"The code trace accuracy is {self.accuracy}")
        print(
            f"All of code lines is {self.code_lines}, ignored code lines is {self.ignore_code_lines}. "
            f"And there are {self.traced_code_lines} of codes appeared in the ir file: {self.ir_file_name}")
        print(f"#analyze() cost time: {self.cost_time}s")
        if self.not_traced_codes:
            print("Below codes are not traced in ir file:")
            for index, line in enumerate(self.not_traced_codes):
                print(f"[{index}] {line}")

    def _read_ir_content(self):
        """
        Read the most recently created `<digits>_<ir_file>_<digits>.ir` file into `self.ir_content`.

        Raises:
            ValueError: If no matching ir file exists in `save_graphs_path`.
        """
        # Escape the name and the '.' of the suffix, and match the whole file name only.
        pattern = re.compile(rf"\d+_{re.escape(self.ir_file_name)}_\d+\.ir")
        ir_files = [os.path.join(self.save_graphs_path, f)
                    for f in os.listdir(self.save_graphs_path) if pattern.fullmatch(f)]
        if not ir_files:
            raise ValueError(f"No ir file matching '{pattern.pattern}' is found in {self.save_graphs_path}.")
        file_name = max(ir_files, key=os.path.getctime)
        with open(file_name, "r", encoding="utf-8") as f:
            self.ir_content = f.read()
        self.ir_file_name = file_name

    def _check_lines(self, fn):
        """Count the source lines of `fn` and record whether each one is traced in the ir content."""
        fn_file_name: str = fn.__code__.co_filename
        if any(item in fn_file_name for item in self.dir_black_lists):
            return

        lines, offset = inspect.getsourcelines(fn)
        for lineno, raw_line in enumerate(lines, start=offset):
            line = raw_line.rstrip("\r\n")
            if self.skip_check(line):
                continue

            self.code_lines += 1

            if "<IGNORE>" in line:
                self.ignore_code_lines += 1
                continue

            # Escape path and code so characters like '?', '|', '$' or '\' match literally, and
            # forbid a further digit after the line number so line 12 cannot match line 120's trace.
            pattern = rf"In file {re.escape(fn_file_name)}:{lineno}(?!\d).*/{re.escape(line)}/"
            if re.search(pattern, self.ir_content):
                self.traced_code_lines += 1
            else:
                self.not_traced_codes.append(f"In file {fn_file_name}:{lineno}/{line}/")

    def _check_net(self, cell):
        """Recursively check the construct of the cell and of its sub cells (black-listed files skipped)."""
        fn = inspect.unwrap(getattr(cell, "construct"))
        self._check_lines(fn)

        for sub_cell in cell.cells():
            self._check_net(sub_cell)

    def _check_function(self, fn):
        """Only check the given function (after unwrapping decorators that set __wrapped__)."""
        fn = inspect.unwrap(fn)
        self._check_lines(fn)