#
# Copyright (c) 2020 Valve Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
import re
import sys
import os.path
import struct
import string
import copy
from math import floor

# Only emit ANSI colour codes when writing to a terminal. The stream's own
# isatty() is used so that a stdout replacement without a file descriptor
# (e.g. an in-memory buffer) does not raise.
if sys.stdout.isatty():
    set_red = "\033[31m"
    set_green = "\033[1;32m"
    set_normal = "\033[0m"
else:
    set_red = ''
    set_green = ''
    set_normal = ''

# Code executed (as the very first check) inside the globals dict shared by
# all checks of a test.  It defines the helpers that check code and '@name'
# pattern functions can use.
initial_code = '''
import re

def insert_code(code):
    insert_queue.append(CodeCheck(code, current_position))

def insert_pattern(pattern):
    insert_queue.append(PatternCheck(pattern, False, current_position))

def vector_gpr(prefix, name, size, align):
    insert_code(f'{name} = {name}0')
    for i in range(size):
        insert_code(f'{name}{i} = {name}0 + {i}')
    insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}')
    insert_code(f'success = {name}0 % {align} == 0')
    return f'{prefix}[#{name}0:#{name}{size - 1}]'

def sgpr_vector(name, size, align):
    return vector_gpr('s', name, size, align)

funcs.update({
    's64': lambda name: vector_gpr('s', name, 2, 2),
    's96': lambda name: vector_gpr('s', name, 3, 2),
    's128': lambda name: vector_gpr('s', name, 4, 4),
    's256': lambda name: vector_gpr('s', name, 8, 4),
    's512': lambda name: vector_gpr('s', name, 16, 4),
})
for i in range(2, 14):
    funcs['v%d' % (i * 32)] = lambda name, size=i: vector_gpr('v', name, size, 1)

def _match_func(names):
    for name in names.split(' '):
        insert_code(f'funcs["{name}"] = lambda _: {name}')
    return ' '.join(f'${name}' for name in names.split(' '))

funcs['match_func'] = _match_func

def search_re(pattern):
    global success
    success = re.search(pattern, output.read_line()) is not None and success

'''


class Check:
    """Base class of a check: some text plus the position it came from."""

    def __init__(self, data, position):
        self.data = data.rstrip()
        self.position = position

    def run(self, state):
        """Run the check against `state`; subclasses return True on success."""
        pass


class CodeCheck(Check):
    """A check whose data is Python code executed in the shared globals."""

    def run(self, state):
        """Dedent and execute the code in `state.g`.

        Returns False (after appending a description to `state.result.log`)
        when the indentation is inconsistent, the code raises, or the code
        leaves `success` falsy in `state.g`; returns True otherwise.
        """
        lines = self.data.split('\n')
        # The indentation of the first non-blank line is stripped from every
        # line.  A check without any non-blank line is simply empty code.
        first_line = next((l for l in lines if l.strip() != ''), '')
        indent_amount = len(first_line) - len(first_line.lstrip())
        indent = first_line[:indent_amount]
        new_lines = []
        for line in lines:
            if line.strip() == '':
                new_lines.append('')
                continue
            if line[:indent_amount] != indent:
                state.result.log += 'unexpected indent in code check:\n'
                state.result.log += self.data + '\n'
                return False
            new_lines.append(line[indent_amount:])
        code = '\n'.join(new_lines)

        try:
            exec(code, state.g)
            # Move whatever the code logged into the test result.
            state.result.log += state.g['log']
            state.g['log'] = ''
        # NOTE(review): BaseException also traps KeyboardInterrupt/SystemExit;
        # kept as in the original so any failure is reported as a failed check.
        except BaseException as e:
            state.result.log += 'code check at %s raised exception:\n' % self.position
            state.result.log += code + '\n'
            state.result.log += str(e)
            return False
        if not state.g['success']:
            state.result.log += 'code check at %s failed:\n' % self.position
            state.result.log += code + '\n'
            return False
        return True


class StringStream:
    """A string with a read cursor that tracks line/column (both 1-based)."""

    class Pos:
        def __init__(self):
            self.line = 1
            self.column = 1

    def __init__(self, data, name):
        self.name = name
        self.data = data
        self.offset = 0
        self.pos = StringStream.Pos()

    def reset(self):
        """Move the cursor back to the start of the data."""
        self.offset = 0
        self.pos = StringStream.Pos()

    def peek(self, num=1):
        """Return up to `num` characters at the cursor without consuming them."""
        return self.data[self.offset:self.offset+num]

    def peek_test(self, chars):
        """Return True if there is a next character and it is in `chars`."""
        c = self.peek(1)
        return c != '' and c in chars

    def read(self, num=4294967296):
        """Consume and return up to `num` characters, updating the position."""
        res = self.peek(num)
        self.offset += len(res)
        for c in res:
            if c == '\n':
                self.pos.line += 1
                self.pos.column = 1
            else:
                self.pos.column += 1
        return res

    def get_line(self, num):
        """Return line number `num` (1-based) of the data, right-stripped."""
        return self.data.split('\n')[num - 1].rstrip()

    def read_line(self):
        """Consume the rest of the current line; return it without the newline."""
        line = ''
        while self.peek(1) not in ['\n', '']:
            line += self.read(1)
        self.read(1)
        return line

    def skip_whitespace(self, inc_line):
        """Consume spaces and tabs, and newlines too when `inc_line` is true."""
        chars = [' ', '\t'] + (['\n'] if inc_line else [])
        while self.peek(1) in chars:
            self.read(1)

    def get_number(self):
        """Consume and return the run of decimal digits at the cursor."""
        num = ''
        # peek_test() is false at end of data; a bare "peek() in digits" test
        # would be true for '' and never terminate.
        while self.peek_test(string.digits):
            num += self.read(1)
        return num

    def check_identifier(self):
        """Return True if the next character can start an identifier."""
        return self.peek_test(string.ascii_letters + '_')

    def get_identifier(self):
        """Consume and return the identifier at the cursor ('' if none)."""
        res = ''
        if self.check_identifier():
            while self.peek_test(string.ascii_letters + string.digits + '_'):
                res += self.read(1)
        return res


def format_error_lines(at, line_num, column_num, ctx, line):
    """Return two lines: the quoted `line` and a '^' marker under the column."""
    pred = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx)
    return [pred + line + '"',
            '-' * (column_num - 1 + len(pred)) + '^']


class MatchResult:
    """Outcome of do_match(): success flag, failure message and positions."""

    def __init__(self, pattern):
        self.success = True
        # MatchResult of a failed '@func' expansion, if any.
        self.func_res = None
        self.pattern = pattern
        self.pattern_pos = StringStream.Pos()
        self.output_pos = StringStream.Pos()
        self.fail_message = ''

    def set_pos(self, pattern, output):
        """Record the current positions of the pattern and output streams."""
        self.pattern_pos.line = pattern.pos.line
        self.pattern_pos.column = pattern.pos.column
        self.output_pos.line = output.pos.line
        self.output_pos.column = output.pos.column

    def fail(self, msg):
        """Mark the match as failed with message `msg`."""
        self.success = False
        self.fail_message = msg

    def format_pattern_pos(self):
        """Describe the pattern position, following nested function results."""
        pat_pos = self.pattern_pos
        pat_line = self.pattern.get_line(pat_pos.line)
        res = format_error_lines('at', pat_pos.line, pat_pos.column, 'pattern', pat_line)
        func_res = self.func_res
        while func_res:
            pat_pos = func_res.pattern_pos
            pat_line = func_res.pattern.get_line(pat_pos.line)
            res += format_error_lines('in', pat_pos.line, pat_pos.column, func_res.pattern.name, pat_line)
            func_res = func_res.func_res
        return '\n'.join(res)


def do_match(g, pattern, output, skip_lines, in_func=False):
    """Match the `pattern` stream against the `output` stream.

    Pattern syntax handled here (a preceding backslash disables the special
    meaning of the next character):
      #name   a decimal number, stored in / compared with g[name]
      $name   a run of non-whitespace, stored in / compared with g[name]
      %name   a '%' followed by a number, stored in / compared with g[name]
      @f(a)   the expansion returned by g['funcs'][f](a), matched recursively
      space   one or more spaces/tabs in the output
      newline at least one newline in the output
    The name '_' matches without storing anything.

    With `skip_lines`, a failed match drops the current output line, restores
    `g` and retries from the start of the pattern until the output is
    exhausted.  `in_func` marks a recursive call for a function expansion.

    Returns a MatchResult.  Raises AssertionError if a '@name(' call in the
    pattern is not closed by ')'.
    """
    assert not in_func or not skip_lines

    if not in_func:
        output.skip_whitespace(False)
    pattern.skip_whitespace(False)

    # Snapshot so variables bound by a failed attempt can be rolled back.
    old_g = copy.copy(g)
    res = MatchResult(pattern)
    escape = False
    while True:
        res.set_pos(pattern, output)

        c = pattern.read(1)
        if c == '':
            break
        elif output.peek() == '':
            res.fail('unexpected end of output')
        elif c == '\\':
            escape = True
            continue
        elif c == '\n':
            old_line = output.pos.line
            output.skip_whitespace(True)
            if output.pos.line == old_line:
                res.fail('expected newline in output')
        elif not escape and c == '#':
            num = output.get_number()
            if num == '':
                res.fail('expected number in output')
            elif pattern.check_identifier():
                name = pattern.get_identifier()
                if name in g and int(num) != g[name]:
                    res.fail("unexpected number for '%s': %d (expected %d)" % (name, int(num), g[name]))
                elif name != '_':
                    g[name] = int(num)
        elif not escape and c == '$':
            name = pattern.get_identifier()

            val = ''
            # Stop at whitespace or at the end of the output; without the
            # end-of-output test this loop would never terminate.
            while output.peek() != '' and not output.peek_test(string.whitespace):
                val += output.read(1)

            if name in g and val != g[name]:
                res.fail("unexpected value for '%s': '%s' (expected '%s')" % (name, val, g[name]))
            elif name != '_':
                g[name] = val
        elif not escape and c == '%' and pattern.check_identifier():
            if output.read(1) != '%':
                res.fail("expected '%' in output")
            else:
                num = output.get_number()
                if num == '':
                    res.fail('expected number in output')
                else:
                    name = pattern.get_identifier()
                    if name in g and int(num) != g[name]:
                        res.fail("unexpected number for '%s': %d (expected %d)" % (name, int(num), g[name]))
                    elif name != '_':
                        g[name] = int(num)
        elif not escape and c == '@' and pattern.check_identifier():
            name = pattern.get_identifier()
            args = ''
            if pattern.peek_test('('):
                pattern.read(1)
                while pattern.peek() not in ['', ')']:
                    args += pattern.read(1)
                # Consume the ')' outside of an assert statement so the read
                # still happens when Python runs with -O.
                if pattern.read(1) != ')':
                    raise AssertionError('missing ")" after arguments of "@%s"' % name)
            func_res = g['funcs'][name](args)
            match_res = do_match(g, StringStream(func_res, 'expansion of "%s(%s)"' % (name, args)), output, False, True)
            if not match_res.success:
                res.func_res = match_res
                res.output_pos = match_res.output_pos
                res.fail(match_res.fail_message)
        elif not escape and c == ' ':
            while pattern.peek_test(' '):
                pattern.read(1)

            read_whitespace = False
            while output.peek_test(' \t'):
                output.read(1)
                read_whitespace = True
            if not read_whitespace:
                res.fail('expected whitespace in output, got %r' % (output.peek(1)))
        else:
            outc = output.peek(1)
            if outc != c:
                res.fail('expected %r in output, got %r' % (c, outc))
            else:
                output.read(1)
        if not res.success:
            if skip_lines and output.peek() != '':
                # Search mode: undo bindings, drop this line and start over.
                g.clear()
                g.update(old_g)
                res.success = True
                output.read_line()
                pattern.reset()
                output.skip_whitespace(False)
                pattern.skip_whitespace(False)
            else:
                return res

        escape = False

    if not in_func:
        while output.peek() in [' ', '\t']:
            output.read(1)

        if output.read(1) not in ['', '\n']:
            res.fail('expected end of output')
            return res

    return res


class PatternCheck(Check):
    """A check whose data is a pattern matched against the output stream."""

    def __init__(self, data, search, position):
        Check.__init__(self, data, position)
        # When true, output lines are skipped until the pattern matches.
        self.search = search

    def run(self, state):
        """Match the pattern against state.g['output']; log details on failure."""
        pattern_stream = StringStream(self.data.rstrip(), 'pattern')
        res = do_match(state.g, pattern_stream, state.g['output'], self.search)
        if not res.success:
            state.result.log += 'pattern at %s failed: %s\n' % (self.position, res.fail_message)
            state.result.log += res.format_pattern_pos() + '\n\n'
            if not self.search:
                out_line = state.g['output'].get_line(res.output_pos.line)
                state.result.log += '\n'.join(format_error_lines('at', res.output_pos.line, res.output_pos.column, 'output', out_line))
            else:
                state.result.log += 'output was:\n'
                state.result.log += state.g['output'].data.rstrip() + '\n'
            return False
        return True


class CheckState:
    """Everything shared between the checks of one test run."""

    def __init__(self, result, variant, checks, output):
        self.result = result
        self.variant = variant
        self.checks = checks

        # The helper definitions always run first (note: mutates `checks`).
        self.checks.insert(0, CodeCheck(initial_code, None))
        self.insert_queue = []

        # Globals dict for executed check code and variable store for patterns.
        self.g = {'success': True, 'funcs': {}, 'insert_queue': self.insert_queue,
                  'variant': variant, 'log': '', 'output': StringStream(output, 'output'),
                  'CodeCheck': CodeCheck, 'PatternCheck': PatternCheck,
                  'current_position': ''}


class TestResult:
    """Result ('passed', 'failed', ...), expected result and log of one test."""

    def __init__(self, expected):
        self.result = ''
        self.expected = expected
        self.log = ''


def check_output(result, variant, checks, output):
    """Run `checks` in order against `output`, storing the verdict in `result`.

    Stops at the first failing check and sets result.result to 'failed';
    sets it to 'passed' when every check succeeds.
    """
    state = CheckState(result, variant, checks, output)

    while len(state.checks):
        check = state.checks.pop(0)
        state.current_position = check.position
        # The helper code (insert_code/insert_pattern) reads this name from
        # the globals dict, so it has to be updated there as well.
        state.g['current_position'] = check.position
        if not check.run(state):
            result.result = 'failed'
            return

        # Checks queued by the one that just ran are executed next, in the
        # order they were queued.
        for check in state.insert_queue[::-1]:
            state.checks.insert(0, check)
        state.insert_queue.clear()

    result.result = 'passed'
    return


def parse_check(variant, line, checks, pos):
    """Parse one check line (comment prefix already removed) into `checks`.

    ';code'        code; merged into a directly preceding CodeCheck
    '!pattern'     pattern that must match at the current output position
    '>>pattern'    pattern searched for in the following output lines
    '~regex<rest>' parse <rest> only if `regex` matches the whole variant
    """
    if line.startswith(';'):
        line = line[1:]
        if len(checks) and isinstance(checks[-1], CodeCheck):
            checks[-1].data += '\n' + line
        else:
            checks.append(CodeCheck(line, pos))
    elif line.startswith('!'):
        checks.append(PatternCheck(line[1:], False, pos))
    elif line.startswith('>>'):
        checks.append(PatternCheck(line[2:], True, pos))
    elif line.startswith('~'):
        # The variant regex ends at the first check marker on the line.
        end = len(line)
        for c in [';', '!', '>>']:
            if line.find(c) != -1 and line.find(c) < end:
                end = line.find(c)
        if end != len(line):
            match = re.match(line[1:end], variant)
            if match and match.end() == len(variant):
                parse_check(variant, line[end:], checks, pos)


def parse_test_source(test_name, variant, fname):
    """Collect the checks of test `test_name` from the source file `fname`.

    Only '//' lines between BEGIN_TEST*(test_name) and END_TEST are used.
    Returns (checks, expected_result) where expected_result is 'passed',
    'todo' or 'failed' depending on which BEGIN_TEST macro was found.
    """
    in_test = False
    test = []
    expected_result = 'passed'
    # The context manager makes sure the file is closed again.
    with open(fname, 'r') as f:
        for line_num, line in enumerate(f, 1):
            if line.startswith('BEGIN_TEST(%s)' % test_name):
                in_test = True
            elif line.startswith('BEGIN_TEST_TODO(%s)' % test_name):
                in_test = True
                expected_result = 'todo'
            elif line.startswith('BEGIN_TEST_FAIL(%s)' % test_name):
                in_test = True
                expected_result = 'failed'
            elif line.startswith('END_TEST'):
                in_test = False
            elif in_test:
                test.append((line_num, line.strip()))

    checks = []
    base_name = os.path.split(fname)[1]
    for line_num, text in test:
        if text.startswith('//'):
            parse_check(variant, text[2:], checks, 'line %d of %s' % (line_num, base_name))

    return checks, expected_result


def parse_and_check_test(test_name, variant, test_file, output, current_result):
    """Return the TestResult of one test.

    `current_result`, if not None, is a (result, log) pair that is used
    instead of checking `output`.  A test without checks is 'empty'.
    """
    checks, expected = parse_test_source(test_name, variant, test_file)

    result = TestResult(expected)
    if len(checks) == 0:
        result.result = 'empty'
        result.log = 'no checks found'
    elif current_result is not None:
        result.result, result.log = current_result
    else:
        check_output(result, variant, checks, output)
        if result.result == 'failed' and expected == 'todo':
            result.result = 'todo'

    return result


def print_results(results, output, expected):
    """Print the tests whose result is `output` and return how many there are.

    `expected` selects tests whose result equals (True) or differs from
    (False) their expected result.
    """
    results = {name: result for name, result in results.items() if result.result == output}
    results = {name: result for name, result in results.items() if (result.result == result.expected) == expected}

    if not results:
        return 0

    print('%s tests (%s):' % (output, 'expected' if expected else 'unexpected'))
    for test, result in results.items():
        color = '' if expected else set_red
        print(' %s%s%s' % (color, test, set_normal))
        if result.log.strip() != '':
            for line in result.log.rstrip().split('\n'):
                print(' ' + line.rstrip())
    print('')

    return len(results)


def get_cstr(fp):
    """Read a NUL-terminated UTF-8 string from the binary stream `fp`.

    Raises EOFError if the stream ends before the terminating NUL byte.
    """
    res = bytearray()
    while True:
        c = fp.read(1)
        if c == b'\x00':
            return res.decode('utf-8')
        if c == b'':
            # read() keeps returning b'' at end of stream; without this test
            # the loop would never terminate on truncated input.
            raise EOFError('end of stream inside a NUL-terminated string')
        res += c


def main():
    """Read test packets from stdin, check each test and print a summary.

    Exits with status 1 if any test had an unexpected result.
    """
    results = {}

    stdin = sys.stdin.buffer
    while True:
        packet_type = stdin.read(4)
        if packet_type == b'':
            break

        test_name = get_cstr(stdin)
        test_variant = get_cstr(stdin)
        if test_variant != '':
            full_name = test_name + '/' + test_variant
        else:
            full_name = test_name

        test_source_file = get_cstr(stdin)
        current_result = None
        # A non-zero byte announces an already determined (result, log) pair.
        if ord(stdin.read(1)):
            current_result = (get_cstr(stdin), get_cstr(stdin))
        code_size = struct.unpack("=L", stdin.read(4))[0]
        code = stdin.read(code_size).decode('utf-8')

        results[full_name] = parse_and_check_test(test_name, test_variant, test_source_file, code, current_result)

    result_types = ['passed', 'failed', 'todo', 'empty']
    num_expected = 0
    num_unexpected = 0
    for t in result_types:
        num_expected += print_results(results, t, True)
    for t in result_types:
        num_unexpected += print_results(results, t, False)
    num_expected_skipped = print_results(results, 'skipped', True)
    num_unexpected_skipped = print_results(results, 'skipped', False)

    num_unskipped = len(results) - num_expected_skipped - num_unexpected_skipped
    color = set_red if num_unexpected else set_green
    # With no unskipped tests there is nothing to divide by; report 100%.
    percent = floor(num_expected / num_unskipped * 100) if num_unskipped else 100
    print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color, num_expected, percent, num_unskipped, set_normal))
    if num_unexpected_skipped:
        print('%s%d tests had been unexpectedly skipped%s' % (set_red, num_unexpected_skipped, set_normal))

    if num_unexpected:
        sys.exit(1)


if __name__ == "__main__":
    main()