#
# Copyright (c) 2020 Valve Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
import re
import sys
import os.path
import struct
import string
import copy
from math import floor

# ANSI colour escapes are only emitted when stdout is an interactive terminal.
# sys.stdout.isatty() is used instead of os.isatty(sys.stdout.fileno()) so that
# a replaced stdout without a file descriptor does not raise at import time.
if sys.stdout.isatty():
    set_red = "\033[31m"
    set_green = "\033[1;32m"
    set_normal = "\033[0m"
else:
    set_red = ''
    set_green = ''
    set_normal = ''

# Prelude executed (as the first CodeCheck) inside the per-test globals dict
# built by CheckState.  It defines the helpers available to code checks and
# registers the '@name(args)' pattern functions in 'funcs'.
#
# The loop that registers the 'vNN' functions binds the loop variable as a
# default argument: a plain closure would read the global 'i' at call time
# and every 'vNN' would expand to the size of the last iteration.
initial_code = '''
def insert_code(code):
    insert_queue.append(CodeCheck(code))

def insert_pattern(pattern):
    insert_queue.append(PatternCheck(pattern, False, 'inserted pattern'))

def vector_gpr(prefix, name, size, align):
    insert_code(f'{name} = {name}0')
    for i in range(size):
        insert_code(f'{name}{i} = {name}0 + {i}')
    insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}')
    insert_code(f'success = {name}0 % {align} == 0')
    return f'{prefix}[#{name}0:#{name}{size - 1}]'

def sgpr_vector(name, size, align):
    return vector_gpr('s', name, size, align)

funcs.update({
    's64': lambda name: vector_gpr('s', name, 2, 2),
    's96': lambda name: vector_gpr('s', name, 3, 2),
    's128': lambda name: vector_gpr('s', name, 4, 4),
    's256': lambda name: vector_gpr('s', name, 8, 4),
    's512': lambda name: vector_gpr('s', name, 16, 4),
})
for i in range(2, 14):
    funcs['v%d' % (i * 32)] = lambda name, size=i: vector_gpr('v', name, size, 1)
'''


class Check:
    """Base class of a single check; stores the right-stripped check text."""

    def __init__(self, data):
        self.data = data.rstrip()

    def run(self, state):
        """Run the check against 'state'; subclasses return True on success."""
        pass


class CodeCheck(Check):
    """A check consisting of Python code executed in the test's globals."""

    def run(self, state):
        """Dedent and exec the code; return False on error or if 'success' is falsy.

        The indentation of the first non-blank line is removed from every
        non-blank line.  Anything the code appends to the 'log' global is
        moved into the result log.
        """
        lines = self.data.split('\n')
        # An entirely blank check has no first line; treat it as zero indent
        # (and therefore as empty code) instead of raising IndexError.
        first_line = next((l for l in lines if l.strip() != ''), '')
        indent_amount = len(first_line) - len(first_line.lstrip())
        indent = first_line[:indent_amount]

        new_lines = []
        for line in lines:
            if line.strip() == '':
                new_lines.append('')
                continue
            if line[:indent_amount] != indent:
                state.result.log += 'unexpected indent in code check:\n'
                state.result.log += self.data + '\n'
                return False
            new_lines.append(line[indent_amount:])
        code = '\n'.join(new_lines)

        try:
            exec(code, state.g)
            state.result.log += state.g['log']
            state.g['log'] = ''
        except BaseException as e:
            state.result.log += 'code check raised exception:\n'
            state.result.log += code + '\n'
            state.result.log += str(e)
            return False

        if not state.g['success']:
            state.result.log += 'code check failed:\n'
            state.result.log += code + '\n'
            return False
        return True


class StringStream:
    """A string with a read cursor that tracks 1-based line/column position."""

    class Pos:
        """A 1-based (line, column) position."""

        def __init__(self):
            self.line = 1
            self.column = 1

    def __init__(self, data, name):
        self.name = name
        self.data = data
        self.offset = 0
        self.pos = StringStream.Pos()

    def reset(self):
        """Move the cursor back to the start of the data."""
        self.offset = 0
        self.pos = StringStream.Pos()

    def peek(self, num=1):
        """Return up to 'num' characters at the cursor without consuming them."""
        return self.data[self.offset:self.offset + num]

    def peek_test(self, chars):
        """Return True if there is a next character and it is in 'chars'."""
        c = self.peek(1)
        return c != '' and c in chars

    def read(self, num=4294967296):
        """Consume and return up to 'num' characters, updating the position."""
        res = self.peek(num)
        self.offset += len(res)
        for c in res:
            if c == '\n':
                self.pos.line += 1
                self.pos.column = 1
            else:
                self.pos.column += 1
        return res

    def get_line(self, num):
        """Return line number 'num' (1-based) of the data, right-stripped."""
        return self.data.split('\n')[num - 1].rstrip()

    def skip_line(self):
        """Consume everything up to and including the next newline."""
        while self.peek(1) not in ['\n', '']:
            self.read(1)
        self.read(1)

    def skip_whitespace(self, inc_line):
        """Consume spaces and tabs, and newlines too when 'inc_line' is true."""
        chars = [' ', '\t'] + (['\n'] if inc_line else [])
        while self.peek(1) in chars:
            self.read(1)

    def get_number(self):
        """Consume a run of decimal digits and return it ('' if none)."""
        num = ''
        # peek_test() is False at end of data; a bare "peek() in digits" test
        # would be True for '' and never terminate.
        while self.peek_test(string.digits):
            num += self.read(1)
        return num

    def check_identifier(self):
        """Return True if the next character can start an identifier."""
        return self.peek_test(string.ascii_letters + '_')

    def get_identifier(self):
        """Consume and return an identifier ('' if none starts here)."""
        res = ''
        if self.check_identifier():
            while self.peek_test(string.ascii_letters + string.digits + '_'):
                res += self.read(1)
        return res


def format_error_lines(at, line_num, column_num, ctx, line):
    """Return two lines: the quoted source line and a caret under the column."""
    pred = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx)
    return [pred + line + '"',
            '-' * (column_num - 1 + len(pred)) + '^']


class MatchResult:
    """Outcome of do_match(): success flag, failure message and positions."""

    def __init__(self, pattern):
        self.success = True
        self.func_res = None
        self.pattern = pattern
        self.pattern_pos = StringStream.Pos()
        self.output_pos = StringStream.Pos()
        self.fail_message = ''

    def set_pos(self, pattern, output):
        """Record the current positions of the pattern and output streams."""
        self.pattern_pos.line = pattern.pos.line
        self.pattern_pos.column = pattern.pos.column
        self.output_pos.line = output.pos.line
        self.output_pos.column = output.pos.column

    def fail(self, msg):
        """Mark the match as failed with the message 'msg'."""
        self.success = False
        self.fail_message = msg

    def format_pattern_pos(self):
        """Format the failing pattern position, following nested expansions."""
        pat_pos = self.pattern_pos
        pat_line = self.pattern.get_line(pat_pos.line)
        res = format_error_lines('at', pat_pos.line, pat_pos.column, 'pattern', pat_line)
        func_res = self.func_res
        while func_res:
            pat_pos = func_res.pattern_pos
            pat_line = func_res.pattern.get_line(pat_pos.line)
            res += format_error_lines('in', pat_pos.line, pat_pos.column, func_res.pattern.name, pat_line)
            func_res = func_res.func_res
        return '\n'.join(res)


def do_match(g, pattern, output, skip_lines, in_func=False):
    """Match the 'pattern' stream against the 'output' stream.

    Pattern syntax, as implemented below:
      '#name'        a decimal number; bound to / compared with g[name]
      '$name'        a run of non-whitespace; bound to / compared with g[name]
      '%name'        a literal '%' followed by a decimal number
      '@name(args)'  expands g['funcs'][name](args) and matches the result
      ' '            one or more spaces/tabs in the output
      newline        at least one newline in the output
      a backslash    makes the next character match literally
    The name '_' matches without binding a variable.

    With 'skip_lines', a failed match is retried from the start of the
    pattern at each following output line, restoring 'g' each time.
    Unless 'in_func', the output line must end where the pattern ends.

    Returns a MatchResult.  Raises ValueError if an '@name(' argument list
    is not closed.
    """
    assert(not in_func or not skip_lines)

    if not in_func:
        output.skip_whitespace(False)
        pattern.skip_whitespace(False)

    old_g = copy.copy(g)
    res = MatchResult(pattern)
    escape = False
    while True:
        res.set_pos(pattern, output)

        c = pattern.read(1)
        if c == '':
            break
        elif output.peek() == '':
            res.fail('unexpected end of output')
        elif c == '\\':
            escape = True
            continue
        elif c == '\n':
            old_line = output.pos.line
            output.skip_whitespace(True)
            if output.pos.line == old_line:
                res.fail('expected newline in output')
        elif not escape and c == '#':
            num = output.get_number()
            if num == '':
                res.fail('expected number in output')
            elif pattern.check_identifier():
                name = pattern.get_identifier()
                if name in g and int(num) != g[name]:
                    res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                elif name != '_':
                    g[name] = int(num)
        elif not escape and c == '$':
            name = pattern.get_identifier()

            val = ''
            # Stop at end of output as well as at whitespace; otherwise an
            # output without trailing whitespace would loop forever here.
            while output.peek() != '' and not output.peek_test(string.whitespace):
                val += output.read(1)

            if name in g and val != g[name]:
                res.fail('unexpected value for \'%s\': \'%s\' (expected \'%s\')' % (name, val, g[name]))
            elif name != '_':
                g[name] = val
        elif not escape and c == '%' and pattern.check_identifier():
            if output.read(1) != '%':
                res.fail('expected \'%\' in output')
            else:
                num = output.get_number()
                if num == '':
                    res.fail('expected number in output')
                else:
                    name = pattern.get_identifier()
                    if name in g and int(num) != g[name]:
                        res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                    elif name != '_':
                        g[name] = int(num)
        elif not escape and c == '@' and pattern.check_identifier():
            name = pattern.get_identifier()
            args = ''
            if pattern.peek_test('('):
                pattern.read(1)
                while pattern.peek() not in ['', ')']:
                    args += pattern.read(1)
                # The read has to happen even under 'python -O', so it must
                # not live inside an assert statement.
                if pattern.read(1) != ')':
                    raise ValueError('missing \')\' after arguments of \'@%s\' in %s' % (name, pattern.name))
            func_res = g['funcs'][name](args)
            match_res = do_match(g, StringStream(func_res, 'expansion of "%s(%s)"' % (name, args)), output, False, True)
            if not match_res.success:
                res.func_res = match_res
                res.output_pos = match_res.output_pos
                res.fail(match_res.fail_message)
        elif not escape and c == ' ':
            while pattern.peek_test(' '):
                pattern.read(1)

            read_whitespace = False
            while output.peek_test(' \t'):
                output.read(1)
                read_whitespace = True
            if not read_whitespace:
                res.fail('expected whitespace in output, got %r' % (output.peek(1)))
        else:
            outc = output.peek(1)
            if outc != c:
                res.fail('expected %r in output, got %r' % (c, outc))
            else:
                output.read(1)

        if not res.success:
            if skip_lines and output.peek() != '':
                # Search mode: undo variable bindings and retry the whole
                # pattern on the next output line.
                g.clear()
                g.update(old_g)
                res.success = True
                # Drop failure details of the abandoned attempt so they
                # cannot leak into the report of a later failure.
                res.fail_message = ''
                res.func_res = None
                output.skip_line()
                pattern.reset()
                output.skip_whitespace(False)
                pattern.skip_whitespace(False)
            else:
                return res

        escape = False

    if not in_func:
        while output.peek() in [' ', '\t']:
            output.read(1)

        if output.read(1) not in ['', '\n']:
            res.fail('expected end of output')
            return res

    return res


class PatternCheck(Check):
    """A check that matches a pattern against the test output."""

    def __init__(self, data, search, position):
        Check.__init__(self, data)
        self.search = search
        self.position = position

    def run(self, state):
        """Match the pattern at the output cursor; log details on failure."""
        pattern_stream = StringStream(self.data.rstrip(), 'pattern')
        res = do_match(state.g, pattern_stream, state.g['output'], self.search)
        if not res.success:
            state.result.log += 'pattern at %s failed: %s\n' % (self.position, res.fail_message)
            state.result.log += res.format_pattern_pos() + '\n\n'
            if not self.search:
                out_line = state.g['output'].get_line(res.output_pos.line)
                state.result.log += '\n'.join(format_error_lines('at', res.output_pos.line, res.output_pos.column, 'output', out_line))
            else:
                state.result.log += 'output was:\n'
                state.result.log += state.g['output'].data.rstrip() + '\n'
            return False
        return True


class CheckState:
    """State shared by all checks of one test: pending checks and globals."""

    def __init__(self, result, variant, checks, output):
        self.result = result
        self.variant = variant
        self.checks = checks

        # The prelude always runs first; note that this modifies the list
        # passed in by the caller.
        self.checks.insert(0, CodeCheck(initial_code))
        self.insert_queue = []

        # Globals dict used both by exec() in CodeCheck and by do_match().
        self.g = {'success': True, 'funcs': {}, 'insert_queue': self.insert_queue,
                  'variant': variant, 'log': '', 'output': StringStream(output, 'output'),
                  'CodeCheck': CodeCheck, 'PatternCheck': PatternCheck}


class TestResult:
    """Result ('passed', 'failed', ...), expected result and log of a test."""

    def __init__(self, expected):
        self.result = ''
        self.expected = expected
        self.log = ''


def check_output(result, variant, checks, output):
    """Run 'checks' in order against 'output' and set result.result.

    Checks queued by a check (through 'insert_queue') run immediately after
    it, in the order they were queued.  Stops at the first failing check.
    """
    state = CheckState(result, variant, checks, output)

    while state.checks:
        check = state.checks.pop(0)
        if not check.run(state):
            result.result = 'failed'
            return

        # Splice queued checks in front of the remaining ones; the queue is
        # cleared in place because the globals dict holds the same list.
        state.checks[:0] = state.insert_queue
        state.insert_queue.clear()

    result.result = 'passed'
    return


def parse_check(variant, line, checks, pos):
    """Parse one check line (without the '//' prefix) and append it to 'checks'.

    ';'      code; consecutive code lines are merged into one CodeCheck
    '!'      pattern that must match at the current output position
    '>>'     pattern that is searched for in the following output lines
    '~regex' prefix: the rest of the line only applies if 'regex' matches
             the whole variant name
    """
    if line.startswith(';'):
        line = line[1:]
        if len(checks) and isinstance(checks[-1], CodeCheck):
            checks[-1].data += '\n' + line
        else:
            checks.append(CodeCheck(line))
    elif line.startswith('!'):
        checks.append(PatternCheck(line[1:], False, pos))
    elif line.startswith('>>'):
        checks.append(PatternCheck(line[2:], True, pos))
    elif line.startswith('~'):
        # The regex ends at the earliest check marker in the line.
        end = len(line)
        for marker in (';', '!', '>>'):
            idx = line.find(marker)
            if idx != -1 and idx < end:
                end = idx
        if end != len(line):
            match = re.match(line[1:end], variant)
            if match and match.end() == len(variant):
                parse_check(variant, line[end:], checks, pos)


def parse_test_source(test_name, variant, fname):
    """Collect the checks of test 'test_name' from the source file 'fname'.

    Checks are the '//' comment lines between 'BEGIN_TEST(name)' (or the
    _TODO/_FAIL forms) and 'END_TEST'.  Returns (checks, expected_result).
    """
    in_test = False
    test = []
    expected_result = 'passed'
    with open(fname, 'r') as source:
        for line_num, line in enumerate(source, 1):
            if line.startswith('BEGIN_TEST(%s)' % test_name):
                in_test = True
            elif line.startswith('BEGIN_TEST_TODO(%s)' % test_name):
                in_test = True
                expected_result = 'todo'
            elif line.startswith('BEGIN_TEST_FAIL(%s)' % test_name):
                in_test = True
                expected_result = 'failed'
            elif line.startswith('END_TEST'):
                in_test = False
            elif in_test:
                test.append((line_num, line.strip()))

    checks = []
    short_name = os.path.split(fname)[1]
    for line_num, text in test:
        if text.startswith('//'):
            parse_check(variant, text[2:], checks, 'line %d of %s' % (line_num, short_name))

    return checks, expected_result


def parse_and_check_test(test_name, variant, test_file, output, current_result):
    """Parse the checks of a test and evaluate them; return a TestResult.

    A test without checks is 'empty'.  If 'current_result' (a
    (result, log) pair) is given it is used instead of running the checks.
    A failure of a test expected to be 'todo' is reported as 'todo'.
    """
    checks, expected = parse_test_source(test_name, variant, test_file)

    result = TestResult(expected)
    if len(checks) == 0:
        result.result = 'empty'
        result.log = 'no checks found'
    elif current_result is not None:
        result.result, result.log = current_result
    else:
        check_output(result, variant, checks, output)
        if result.result == 'failed' and expected == 'todo':
            result.result = 'todo'

    return result


def print_results(results, output, expected):
    """Print the tests whose result is 'output'; return how many there are.

    'expected' selects tests whose result equals (True) or differs from
    (False) their expected result.
    """
    results = {name: result for name, result in results.items()
               if result.result == output and (result.result == result.expected) == expected}

    if not results:
        return 0

    print('%s tests (%s):' % (output, 'expected' if expected else 'unexpected'))
    for test, result in results.items():
        color = '' if expected else set_red
        print('   %s%s%s' % (color, test, set_normal))
        if result.log.strip() != '':
            for line in result.log.rstrip().split('\n'):
                print('      ' + line.rstrip())
    print('')

    return len(results)


def get_cstr(fp):
    """Read a NUL-terminated UTF-8 string from the binary stream 'fp'.

    Raises EOFError if the stream ends before the terminator.
    """
    res = bytearray()
    while True:
        c = fp.read(1)
        if c == b'':
            # read() keeps returning b'' at EOF; without this check the loop
            # would never terminate on truncated input.
            raise EOFError('unexpected end of input while reading a string')
        if c == b'\x00':
            return res.decode('utf-8')
        res += c


def main():
    """Read test packets from stdin, check them and print a summary.

    Exits with status 1 if any test had an unexpected result.
    """
    results = {}

    stdin = sys.stdin.buffer
    while True:
        packet_type = stdin.read(4)
        if packet_type == b'':
            break

        test_name = get_cstr(stdin)
        test_variant = get_cstr(stdin)
        if test_variant != '':
            full_name = test_name + '/' + test_variant
        else:
            full_name = test_name

        test_source_file = get_cstr(stdin)
        current_result = None
        if ord(stdin.read(1)):
            current_result = (get_cstr(stdin), get_cstr(stdin))
        code_size = struct.unpack("=L", stdin.read(4))[0]
        code = stdin.read(code_size).decode('utf-8')

        results[full_name] = parse_and_check_test(test_name, test_variant, test_source_file, code, current_result)

    result_types = ['passed', 'failed', 'todo', 'empty']
    num_expected = 0
    num_unexpected = 0
    for t in result_types:
        num_expected += print_results(results, t, True)
    for t in result_types:
        num_unexpected += print_results(results, t, False)
    num_expected_skipped = print_results(results, 'skipped', True)
    num_unexpected_skipped = print_results(results, 'skipped', False)

    num_unskipped = len(results) - num_expected_skipped - num_unexpected_skipped
    color = set_red if num_unexpected else set_green
    # With no unskipped tests there is nothing unexpected; avoid dividing by zero.
    percent = floor(num_expected / num_unskipped * 100) if num_unskipped else 100
    print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color, num_expected, percent, num_unskipped, set_normal))
    if num_unexpected_skipped:
        print('%s%d tests had been unexpectedly skipped%s' % (set_red, num_unexpected_skipped, set_normal))

    if num_unexpected:
        sys.exit(1)


if __name__ == "__main__":
    main()