#!/usr/bin/env python3
"""A script to generate FileCheck statements for mlir unit tests.

This script is a utility to add FileCheck patterns to an mlir file.

NOTE: The input .mlir is expected to be the output from the parser, not a
stripped down variant.

Example usage:
$ generate-test-checks.py foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'

The script will heuristically generate CHECK/CHECK-LABEL commands for each line
within the file. By default this script will also try to insert string
substitution blocks for all SSA value names. If --source file is specified, the
script will attempt to insert the generated CHECKs to the source file by looking
for line positions matched by --source_delim_regex.

The script is designed to make adding checks to a test case fast, it is *not*
designed to be authoritative about what constitutes a good test!
"""

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import os  # Used to advertise this file's name ("autogenerated_note").
import re
import sys

ADVERT = '// NOTE: Assertions have been autogenerated by '

# Regex command to match an SSA identifier.
SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)


class SSAVariableNamer:
    """Generate and manage string substitution names for SSA value names.

    Names are stored in a stack of scopes (one dict per scope) that maps an
    SSA value name to the FileCheck variable name generated for it.
    """

    def __init__(self):
        self.scopes = []
        self.name_counter = 0

    def generate_name(self, ssa_name):
        """Create a 'VAL_<n>' name for ssa_name in the innermost scope.

        Raises IndexError if no scope has been pushed yet.
        """
        variable = 'VAL_' + str(self.name_counter)
        self.name_counter += 1
        self.scopes[-1][ssa_name] = variable
        return variable

    def push_name_scope(self):
        """Push a new variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the last variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the level of nesting (number of pushed scopes)."""
        return len(self.scopes)

    def clear_counter(self):
        """Reset the counter used to number generated names."""
        self.name_counter = 0


def process_line(line_chunks, variable_namer):
    """Rewrite the chunks of a line that was split at each '%'.

    Each chunk is the text that followed one '%' in the original line. When
    the chunk starts with an SSA identifier, the identifier is replaced by a
    FileCheck substitution block: '%[[VAL_n:.*]]' on first use, '%[[VAL_n]]'
    on later uses. When it does not (for example a '%' inside a string
    attribute, or a trailing '%'), the '%' is emitted literally.

    Returns the rewritten text, right-stripped, with a trailing newline.
    """
    pieces = []

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # Not an SSA value: restore the '%' that str.split() consumed.
            pieces.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # If one exists, then output the existing name.
            pieces.append('%[[' + variable + ']]')
        else:
            # Otherwise, generate a new variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append('%[[' + variable + ':.*]]')

        # Append the non named group.
        pieces.append(chunk[len(ssa_name):])

    return ''.join(pieces).rstrip() + '\n'


def process_source_lines(source_lines, note, args):
    """Split the source file lines into segments at each delimiter line.

    The source file doesn't have to be .mlir. Lines equal to `note` and lines
    containing args.check_prefix are dropped (they come from a previous run).
    A new segment starts at every line matching args.source_delim_regex.

    Returns a list of segments; each segment is a list of lines with a
    trailing newline appended.
    """
    source_split_re = re.compile(args.source_delim_regex)

    source_segments = [[]]
    for line in source_lines:
        # Remove previous note.
        if line == note:
            continue
        # Remove previous CHECK lines.
        if args.check_prefix in line:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + '\n')
    return source_segments


def preprocess_line(line):
    """Escape character sequences that would be problematic with FileCheck."""
    # Replace any double brackets, '[[' with escaped replacements. '[['
    # corresponds to variable names in FileCheck.
    output_line = line.replace('[[', '{{\\[\\[}}')

    # Replace any single brackets that are followed by an SSA identifier, the
    # identifier will be replaced by a variable; creating the same situation
    # as above.
    output_line = output_line.replace('[%', '{{\\[}}%')

    return output_line


def _generate_check_segments(input_lines, check_prefix, starts_from_scope):
    """Turn right-stripped input lines into segments of CHECK lines."""
    label_padding = ' ' * len('-LABEL')
    output_segments = [[]]
    # Holds the data used for naming SSA value names.
    variable_namer = SSAVariableNamer()

    for input_line in input_lines:
        if not input_line:
            continue
        lstripped_input_line = input_line.lstrip()
        if not lstripped_input_line:
            continue

        # Lines with blocks begin with a ^. These lines have a trailing
        # comment that needs to be stripped.
        if lstripped_input_line.startswith('^'):
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line.startswith('}'):
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line.endswith('{'):
            variable_namer.push_name_scope()
            if cur_level == starts_from_scope:
                output_segments.append([])

        # Omit lines at the near top level e.g. "module {".
        if cur_level < starts_from_scope:
            continue

        # Value numbering restarts at the beginning of every segment.
        if not output_segments[-1]:
            variable_namer.clear_counter()

        # Preprocess the input to remove any sequences that may be
        # problematic with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at the each SSA value name.
        ssa_split = input_line.split('%')

        # If this is a top-level operation use 'CHECK-LABEL', otherwise
        # 'CHECK:'.
        if output_segments[-1] or not ssa_split[0]:
            # Pad to align with the 'LABEL' statements, then output the first
            # line chunk (which does not contain an SSA name) followed by the
            # processed rest of the line.
            output_line = ('// ' + check_prefix + ': ' + label_padding +
                           ssa_split[0] +
                           process_line(ssa_split[1:], variable_namer))
        else:
            # Output the first line chunk that does not contain an SSA name
            # for the label.
            output_line = ('// ' + check_prefix + '-LABEL: ' + ssa_split[0] +
                           '\n')

            # Process the rest of the input line on separate check lines,
            # padded to align with the original position in the line.
            same_prefix = ('// ' + check_prefix + '-SAME: ' +
                           ' ' * len(ssa_split[0]))
            for argument in ssa_split[1:]:
                output_line += same_prefix + process_line([argument],
                                                          variable_namer)

        output_segments[-1].append(output_line)

    return output_segments


def _write_output(output, note, output_segments, source_segments):
    """Write the note and the CHECK segments, interleaved with the source."""
    output.write(note + '\n')

    if source_segments is not None:
        for check_segment, source_segment in zip(output_segments,
                                                 source_segments):
            output.writelines(check_segment)
            output.writelines(source_segment)
    else:
        for segment in output_segments:
            output.write('\n')
            output.writelines(segment)
        output.write('\n')


def main():
    """Parse the command line, generate the CHECK lines and write them out."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK',
        help='Prefix to use from check file.')
    parser.add_argument(
        '-o',
        '--output',
        nargs='?',
        type=argparse.FileType('w'),
        default=None)
    parser.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('r'),
        default=sys.stdin)
    parser.add_argument(
        '--source', type=str,
        help='Print each CHECK chunk before each delimeter line in the source '
        'file, respectively. The delimeter lines are identified by '
        '--source_delim_regex.')
    parser.add_argument('--source_delim_regex', type=str, default='func @')
    parser.add_argument(
        '--starts_from_scope', type=int, default=1,
        help='Omit the top specified level of content. For example, by default '
        'it omits "module {"')
    parser.add_argument('-i', '--inplace', action='store_true', default=False)

    args = parser.parse_args()

    # Reject inconsistent options up front instead of failing half-way.
    if args.inplace:
        if not args.source:
            parser.error('--inplace requires --source')
        if args.output is not None:
            parser.error('--inplace cannot be combined with --output')

    # Read the given input file.
    input_lines = [l.rstrip() for l in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = (ADVERT + 'utils/' + script_name)

    source_segments = None
    if args.source:
        with open(args.source, 'r') as source_file:
            source_lines = [l.rstrip() for l in source_file]
        source_segments = process_source_lines(
            source_lines, autogenerated_note, args)

    output_segments = _generate_check_segments(
        input_lines, args.check_prefix, args.starts_from_scope)

    if (source_segments is not None and
            len(output_segments) != len(source_segments)):
        sys.exit('error: generated %d CHECK segment(s) but the source file '
                 'splits into %d segment(s); adjust --source_delim_regex or '
                 '--starts_from_scope' %
                 (len(output_segments), len(source_segments)))

    # Only open the destination now: with --inplace this truncates the source
    # file, which must not happen before all of the checks were generated.
    if args.inplace:
        output = open(args.source, 'w')
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        _write_output(output, autogenerated_note, output_segments,
                      source_segments)
    finally:
        if output is not sys.stdout:
            output.close()


if __name__ == '__main__':
    main()