• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# Copyright (c) 2020 Valve Corporation
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22import re
23import sys
24import os.path
25import struct
26import string
27import copy
28from math import floor
29
# ANSI escape sequences used to colorize the summary. They are left empty when
# stdout is not a terminal so that redirected output carries no control codes.
set_red, set_green, set_normal = (
    ("\033[31m", "\033[1;32m", "\033[0m")
    if os.isatty(sys.stdout.fileno())
    else ('', '', '')
)
38
# Prelude that CheckState runs as the very first CodeCheck of every test. It is
# exec'd with the check namespace as globals, so the names it uses without
# defining (funcs, insert_queue, CodeCheck, PatternCheck, current_position,
# output, success) are the entries of that namespace.
#
# The 'vNN' helpers bind the loop variable as a lambda default argument
# (i=i): a plain closure would look `i` up in the namespace when it is
# called, i.e. after the loop has finished, and every helper would use the
# same size.
initial_code = '''
import re

def insert_code(code):
    insert_queue.append(CodeCheck(code, current_position))

def insert_pattern(pattern):
    insert_queue.append(PatternCheck(pattern, False, current_position))

def vector_gpr(prefix, name, size, align):
    insert_code(f'{name} = {name}0')
    for i in range(size):
        insert_code(f'{name}{i} = {name}0 + {i}')
    insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}')
    insert_code(f'success = {name}0 % {align} == 0')
    return f'{prefix}[#{name}0:#{name}{size - 1}]'

def sgpr_vector(name, size, align):
    return vector_gpr('s', name, size, align)

funcs.update({
    's64': lambda name: vector_gpr('s', name, 2, 2),
    's96': lambda name: vector_gpr('s', name, 3, 2),
    's128': lambda name: vector_gpr('s', name, 4, 4),
    's256': lambda name: vector_gpr('s', name, 8, 4),
    's512': lambda name: vector_gpr('s', name, 16, 4),
})
for i in range(2, 14):
    funcs['v%d' % (i * 32)] = lambda name, i=i: vector_gpr('v', name, i, 1)

def _match_func(names):
    for name in names.split(' '):
        insert_code(f'funcs["{name}"] = lambda _: {name}')
    return ' '.join(f'${name}' for name in names.split(' '))

funcs['match_func'] = _match_func

def search_re(pattern):
    global success
    success = re.search(pattern, output.read_line()) != None and success

'''
81
class Check:
    """Base class for a single check taken from a test's source comments."""

    def __init__(self, data, position):
        """Store the check text (trailing whitespace removed) and its position label."""
        self.position = position
        self.data = data.rstrip()

    def run(self, state):
        """Do nothing; subclasses override this to perform the actual check."""
        return None
89
class CodeCheck(Check):
    """A check whose data is Python code executed in the check namespace."""

    def run(self, state):
        """Dedent and execute the code, then report whether the check passed.

        The code runs with ``state.g`` as its globals. Anything it leaves in
        ``state.g['log']`` is moved to ``state.result.log``. The check fails
        if a line is not indented like the first non-blank line, if the code
        raises, or if ``state.g['success']`` is false afterwards.

        Returns True on success and False on failure (details are appended to
        ``state.result.log``).
        """
        lines = self.data.split('\n')
        # The indentation of the first non-blank line is removed from every
        # line. A check made only of blank lines has no such line; it is then
        # treated as unindented (and executes nothing) instead of crashing.
        first_line = next((l for l in lines if l.strip() != ''), '')
        indent_amount = len(first_line) - len(first_line.lstrip())
        indent = first_line[:indent_amount]
        new_lines = []
        for line in lines:
            if line.strip() == '':
                new_lines.append('')
                continue
            if line[:indent_amount] != indent:
                state.result.log += 'unexpected indent in code check:\n'
                state.result.log += self.data + '\n'
                return False
            new_lines.append(line[indent_amount:])
        code = '\n'.join(new_lines)

        try:
            exec(code, state.g)
            state.result.log += state.g['log']
            state.g['log'] = ''
        except BaseException as e:
            # Everything the check code raises is turned into a test failure.
            state.result.log += 'code check at %s raised exception:\n' % self.position
            state.result.log += code + '\n'
            state.result.log += str(e)
            return False
        if not state.g['success']:
            state.result.log += 'code check at %s failed:\n' % self.position
            state.result.log += code + '\n'
            return False
        return True
122
class StringStream:
    """Read cursor over a string that tracks a 1-based line/column position."""

    class Pos:
        """Line and column of a cursor; both start at 1."""

        def __init__(self):
            self.line = 1
            self.column = 1

    def __init__(self, data, name):
        """Create a stream over ``data``; ``name`` labels it in error messages."""
        self.name = name
        self.data = data
        self.offset = 0
        self.pos = StringStream.Pos()

    def reset(self):
        """Move the cursor back to the start of the data."""
        self.offset = 0
        self.pos = StringStream.Pos()

    def peek(self, num=1):
        """Return up to ``num`` upcoming characters without consuming them.

        Returns '' once the end of the data has been reached.
        """
        return self.data[self.offset:self.offset+num]

    def peek_test(self, chars):
        """Return True if there is a next character and it is in ``chars``."""
        c = self.peek(1)
        return c != '' and c in chars

    def read(self, num=4294967296):
        """Consume and return up to ``num`` characters, updating the position."""
        res = self.peek(num)
        self.offset += len(res)
        for c in res:
            if c == '\n':
                self.pos.line += 1
                self.pos.column = 1
            else:
                self.pos.column += 1
        return res

    def get_line(self, num):
        """Return line number ``num`` (1-based) of the data, right-stripped."""
        return self.data.split('\n')[num - 1].rstrip()

    def read_line(self):
        """Consume the rest of the current line and return it without the newline."""
        end = self.data.find('\n', self.offset)
        if end == -1:
            end = len(self.data)
        line = self.read(end - self.offset)
        # Consume the newline itself (a no-op at the end of the data).
        self.read(1)
        return line

    def skip_whitespace(self, inc_line):
        """Consume spaces and tabs, and also newlines when ``inc_line`` is true."""
        chars = [' ', '\t'] + (['\n'] if inc_line else [])
        while self.peek(1) in chars:
            self.read(1)

    def get_number(self):
        """Consume a run of decimal digits and return it ('' if there is none)."""
        num = ''
        # peek_test() is false at the end of the data. Testing
        # `self.peek() in string.digits` directly would be true there
        # ('' is a substring of every string) and never terminate.
        while self.peek_test(string.digits):
            num += self.read(1)
        return num

    def check_identifier(self):
        """Return True if the next character can start an identifier."""
        return self.peek_test(string.ascii_letters + '_')

    def get_identifier(self):
        """Consume an identifier ([A-Za-z_][A-Za-z0-9_]*) and return it, or ''."""
        res = ''
        if self.check_identifier():
            while self.peek_test(string.ascii_letters + string.digits + '_'):
                res += self.read(1)
        return res
187
def format_error_lines(at, line_num, column_num, ctx, line):
    """Build a two-line error excerpt: the quoted line and a caret under a column.

    The first returned string is a location prefix followed by ``line`` in
    double quotes. The second is a run of dashes ending in '^', sized so the
    caret sits under column ``column_num`` (1-based) of the quoted line.
    """
    prefix = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx)
    quoted = prefix + line + '"'
    marker = '-' * (len(prefix) + column_num - 1) + '^'
    return [quoted, marker]
192
class MatchResult:
    """Outcome of matching one pattern stream against the output stream."""

    def __init__(self, pattern):
        """Start as a successful result for the given pattern stream."""
        self.pattern = pattern
        self.success = True
        self.fail_message = ''
        # Result of a failed nested function expansion, if any.
        self.func_res = None
        self.pattern_pos = StringStream.Pos()
        self.output_pos = StringStream.Pos()

    def set_pos(self, pattern, output):
        """Record the current positions of both streams."""
        self.pattern_pos.line, self.pattern_pos.column = pattern.pos.line, pattern.pos.column
        self.output_pos.line, self.output_pos.column = output.pos.line, output.pos.column

    def fail(self, msg):
        """Mark the result as failed with the given message."""
        self.fail_message = msg
        self.success = False

    def format_pattern_pos(self):
        """Describe where in the pattern matching stopped.

        The first excerpt is for this result's own pattern; one more excerpt
        is added for every nested result reachable through ``func_res``.
        """
        own = self.pattern_pos
        lines = format_error_lines('at', own.line, own.column, 'pattern',
                                   self.pattern.get_line(own.line))
        nested = self.func_res
        while nested:
            where = nested.pattern_pos
            text = nested.pattern.get_line(where.line)
            lines.extend(format_error_lines('in', where.line, where.column, nested.pattern.name, text))
            nested = nested.func_res
        return '\n'.join(lines)
223
def do_match(g, pattern, output, skip_lines, in_func=False):
    """Match ``pattern`` against ``output`` at its current position.

    Pattern characters match themselves, except for (unless preceded by a
    backslash, which makes the next character literal):

    * ``#name`` - a decimal number; it is compared with ``g[name]`` if that
      exists, otherwise stored there. A bare ``#`` accepts any number.
    * ``$name`` - a run of non-whitespace characters, compared/stored as a string.
    * ``%name`` - a literal '%' followed by a decimal number, compared/stored
      like ``#name``.
    * ``@name(args)`` - calls ``g['funcs'][name](args)`` and matches the
      returned string as a nested pattern.
    * a space - one or more spaces/tabs in the output.
    * a newline - skips whitespace in the output and requires that at least
      one line break was crossed.

    The name ``_`` is never stored.

    Parameters:
        g: dict of check variables; updated with newly bound names.
        pattern: StringStream holding the pattern.
        output: StringStream holding the text to match; consumed as it matches.
        skip_lines: when true, a mismatch restores ``g``, drops the rest of
            the current output line and retries the pattern from its start,
            until the output is exhausted.
        in_func: true for nested function expansions, which neither skip
            leading whitespace in the output nor require the match to end at
            the end of a line. Incompatible with ``skip_lines``.

    Returns a MatchResult describing success or the first failure.
    """
    assert(not in_func or not skip_lines)

    if not in_func:
        output.skip_whitespace(False)
    pattern.skip_whitespace(False)

    # Snapshot used to undo variable bindings when a line is skipped.
    old_g = copy.copy(g)
    res = MatchResult(pattern)
    escape = False
    while True:
        res.set_pos(pattern, output)

        c = pattern.read(1)
        if c == '':
            break
        elif output.peek() == '':
            res.fail('unexpected end of output')
        elif c == '\\':
            escape = True
            continue
        elif c == '\n':
            old_line = output.pos.line
            output.skip_whitespace(True)
            if output.pos.line == old_line:
                res.fail('expected newline in output')
        elif not escape and c == '#':
            num = output.get_number()
            if num == '':
                res.fail('expected number in output')
            elif pattern.check_identifier():
                name = pattern.get_identifier()
                if name in g and int(num) != g[name]:
                    res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                elif name != '_':
                    g[name] = int(num)
        elif not escape and c == '$':
            name = pattern.get_identifier()

            val = ''
            # Stop at whitespace or at the end of the output; without the
            # end-of-output test this loop would never terminate when the
            # value is the last thing in the output.
            while output.peek() != '' and not output.peek_test(string.whitespace):
                val += output.read(1)

            if name in g and val != g[name]:
                res.fail('unexpected value for \'%s\': \'%s\' (expected \'%s\')' % (name, val, g[name]))
            elif name != '_':
                g[name] = val
        elif not escape and c == '%' and pattern.check_identifier():
            if output.read(1) != '%':
                res.fail('expected \'%\' in output')
            else:
                num = output.get_number()
                if num == '':
                    res.fail('expected number in output')
                else:
                    name = pattern.get_identifier()
                    if name in g and int(num) != g[name]:
                        res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                    elif name != '_':
                        g[name] = int(num)
        elif not escape and c == '@' and pattern.check_identifier():
            name = pattern.get_identifier()
            args = ''
            if pattern.peek_test('('):
                pattern.read(1)
                while pattern.peek() not in ['', ')']:
                    args += pattern.read(1)
                # Consume the ')' outside of the assert so that it is still
                # consumed when asserts are disabled (python -O).
                closing = pattern.read(1)
                assert closing == ')'
            func_res = g['funcs'][name](args)
            match_res = do_match(g, StringStream(func_res, 'expansion of "%s(%s)"' % (name, args)), output, False, True)
            if not match_res.success:
                res.func_res = match_res
                res.output_pos = match_res.output_pos
                res.fail(match_res.fail_message)
        elif not escape and c == ' ':
            while pattern.peek_test(' '):
                pattern.read(1)

            read_whitespace = False
            while output.peek_test(' \t'):
                output.read(1)
                read_whitespace = True
            if not read_whitespace:
                res.fail('expected whitespace in output, got %r' % (output.peek(1)))
        else:
            outc = output.peek(1)
            if outc != c:
                res.fail('expected %r in output, got %r' % (c, outc))
            else:
                output.read(1)
        if not res.success:
            if skip_lines and output.peek() != '':
                # Search mode: forget this attempt and retry the whole
                # pattern on the next output line.
                g.clear()
                g.update(old_g)
                res.success = True
                output.read_line()
                pattern.reset()
                output.skip_whitespace(False)
                pattern.skip_whitespace(False)
            else:
                return res

        escape = False

    if not in_func:
        # A top-level pattern must account for the whole output line.
        while output.peek() in [' ', '\t']:
            output.read(1)

        if output.read(1) not in ['', '\n']:
            res.fail('expected end of output')
            return res

    return res
339
class PatternCheck(Check):
    """A check whose data is a pattern matched against the test output."""

    def __init__(self, data, search, position):
        """``search`` selects line-skipping mode: retry the pattern on later lines."""
        super().__init__(data, position)
        self.search = search

    def run(self, state):
        """Match the pattern against ``state.g['output']``.

        Returns True on a match. Otherwise a description of the failure is
        appended to ``state.result.log`` and False is returned.
        """
        stream = StringStream(self.data.rstrip(), 'pattern')
        res = do_match(state.g, stream, state.g['output'], self.search)
        if res.success:
            return True

        state.result.log += 'pattern at %s failed: %s\n' % (self.position, res.fail_message)
        state.result.log += res.format_pattern_pos() + '\n\n'
        if self.search:
            # In search mode no single output position is to blame, so the
            # whole output is shown.
            state.result.log += 'output was:\n'
            state.result.log += state.g['output'].data.rstrip() + '\n'
        else:
            where = res.output_pos
            out_line = state.g['output'].get_line(where.line)
            excerpt = format_error_lines('at', where.line, where.column, 'output', out_line)
            state.result.log += '\n'.join(excerpt)
        return False
359
class CheckState:
    """Everything shared between the checks of one test run."""

    def __init__(self, result, variant, checks, output):
        """Set up the check list and the namespace the checks operate on.

        ``checks`` is modified in place: a CodeCheck running ``initial_code``
        is put in front of it.
        """
        self.result = result
        self.variant = variant
        # Checks queued by running checks; shared with the namespace below.
        self.insert_queue = []

        checks.insert(0, CodeCheck(initial_code, None))
        self.checks = checks

        # Globals for code checks and variable store for pattern checks.
        self.g = {
            'success': True,
            'funcs': {},
            'insert_queue': self.insert_queue,
            'variant': variant,
            'log': '',
            'output': StringStream(output, 'output'),
            'CodeCheck': CodeCheck,
            'PatternCheck': PatternCheck,
            'current_position': '',
        }
373
class TestResult:
    """Result of one test: the outcome, the expected outcome and a log."""

    def __init__(self, expected):
        """Create a result with no outcome yet and an empty log."""
        self.expected = expected
        self.result = ''
        self.log = ''
379
def check_output(result, variant, checks, output):
    """Run ``checks`` in order against ``output`` and store the outcome in ``result``.

    ``result.result`` becomes 'failed' as soon as one check fails and 'passed'
    if all of them succeed. Checks that a running check adds to the insert
    queue are executed right after it, before the remaining checks.
    """
    state = CheckState(result, variant, checks, output)

    while state.checks:
        check = state.checks.pop(0)
        state.current_position = check.position
        # insert_code()/insert_pattern() in the check namespace read
        # 'current_position' from there, so it has to be published in
        # state.g as well; otherwise queued checks report an empty position.
        state.g['current_position'] = check.position
        if not check.run(state):
            result.result = 'failed'
            return

        # Put the queued checks, in queue order, in front of the remaining ones.
        state.checks[:0] = state.insert_queue
        state.insert_queue.clear()

    result.result = 'passed'
    return
396
def parse_check(variant, line, checks, pos):
    """Parse one check line and append the resulting check to ``checks``.

    The first character(s) of ``line`` select the kind of check:

    * ``;``  - Python code; consecutive code lines are merged into one CodeCheck.
    * ``!``  - a pattern that must match at the current output position.
    * ``>>`` - a pattern that is searched for on the following output lines.
    * ``~regex`` followed by one of the markers above - the rest of the line
      is parsed as a check only if ``regex`` matches the whole of ``variant``.

    Lines with any other prefix are ignored.
    """
    if line.startswith(';'):
        line = line[1:]
        if len(checks) and isinstance(checks[-1], CodeCheck):
            checks[-1].data += '\n' + line
        else:
            checks.append(CodeCheck(line, pos))
    elif line.startswith('!'):
        checks.append(PatternCheck(line[1:], False, pos))
    elif line.startswith('>>'):
        checks.append(PatternCheck(line[2:], True, pos))
    elif line.startswith('~'):
        # The variant regex extends up to the earliest check marker.
        found = [idx for idx in (line.find(marker) for marker in (';', '!', '>>')) if idx != -1]
        if found:
            end = min(found)
            # fullmatch() rather than match() plus an end test, so that an
            # alternation whose first branch is only a prefix of the variant
            # still gets the chance to match through a later branch.
            if re.fullmatch(line[1:end], variant):
                parse_check(variant, line[end:], checks, pos)
418
def parse_test_source(test_name, variant, fname):
    """Collect the checks of test ``test_name`` from source file ``fname``.

    The test body is the text between a line starting with
    ``BEGIN_TEST(name)``, ``BEGIN_TEST_TODO(name)`` or ``BEGIN_TEST_FAIL(name)``
    and the next line starting with ``END_TEST``. Within it, every line that
    starts with ``//`` (after stripping) is handed to parse_check().

    Returns ``(checks, expected_result)`` where ``expected_result`` is
    'passed', 'todo' or 'failed' depending on the BEGIN macro used.
    """
    in_test = False
    test = []
    expected_result = 'passed'
    # The context manager closes the file even if reading fails.
    with open(fname, 'r') as source:
        for line_num, line in enumerate(source, 1):
            if line.startswith('BEGIN_TEST(%s)' % test_name):
                in_test = True
            elif line.startswith('BEGIN_TEST_TODO(%s)' % test_name):
                in_test = True
                expected_result = 'todo'
            elif line.startswith('BEGIN_TEST_FAIL(%s)' % test_name):
                in_test = True
                expected_result = 'failed'
            elif line.startswith('END_TEST'):
                in_test = False
            elif in_test:
                test.append((line_num, line.strip()))

    checks = []
    base_name = os.path.split(fname)[1]
    for line_num, text in test:
        if text.startswith('//'):
            parse_check(variant, text[2:], checks, 'line %d of %s' % (line_num, base_name))

    return checks, expected_result
444
def parse_and_check_test(test_name, variant, test_file, output, current_result):
    """Produce the TestResult for one test.

    The checks are parsed from ``test_file``. If there are none, the result
    is 'empty'. If ``current_result`` is given (a ``(result, log)`` pair) it
    is used as is. Otherwise the checks are run against ``output``; a failure
    of a test that is expected to be 'todo' is reported as 'todo'.
    """
    checks, expected = parse_test_source(test_name, variant, test_file)

    result = TestResult(expected)
    if not checks:
        result.result = 'empty'
        result.log = 'no checks found'
    elif current_result is not None:
        result.result, result.log = current_result
    else:
        check_output(result, variant, checks, output)
        if result.result == 'failed' and expected == 'todo':
            result.result = 'todo'

    return result
460
def print_results(results, output, expected):
    """Print the tests whose outcome is ``output`` and return how many there are.

    Only tests whose outcome equals ``output`` are considered, and of those
    only the ones whose "outcome equals expected outcome" status equals
    ``expected``. Nothing is printed when no test qualifies. Unexpected
    results are printed in red; each test's log follows its name, indented.
    """
    selected = {}
    for name, result in results.items():
        if result.result != output:
            continue
        if (result.result == result.expected) != expected:
            continue
        selected[name] = result

    if not selected:
        return 0

    heading = 'expected' if expected else 'unexpected'
    print('%s tests (%s):' % (output, heading))
    for test, result in selected.items():
        color = '' if expected else set_red
        print('   %s%s%s' % (color, test, set_normal))
        if result.log.strip() == '':
            continue
        for line in result.log.rstrip().split('\n'):
            print('      ' + line.rstrip())
    print('')

    return len(selected)
478
def get_cstr(fp):
    """Read a NUL-terminated UTF-8 string from binary stream ``fp``.

    The terminating NUL byte is consumed but not returned.

    Raises:
        EOFError: if the stream ends before a NUL byte is found.
    """
    res = bytearray()
    while True:
        c = fp.read(1)
        if c == b'':
            # read() returns b'' at end of stream; without this test the
            # loop would never terminate on truncated input.
            raise EOFError('unexpected end of stream while reading a string')
        if c == b'\x00':
            return res.decode('utf-8')
        res += c
487
if __name__ == "__main__":
    # Reads a sequence of test packets from stdin, checks each test and
    # prints a summary. The exit status is 1 if any unskipped test had an
    # unexpected result.
    results = {}

    stdin = sys.stdin.buffer
    while True:
        # Every packet starts with four bytes; only their presence is tested.
        packet_type = stdin.read(4)
        if packet_type == b'':
            break

        test_name = get_cstr(stdin)
        test_variant = get_cstr(stdin)
        if test_variant != '':
            full_name = test_name + '/' + test_variant
        else:
            full_name = test_name

        test_source_file = get_cstr(stdin)
        current_result = None
        # A non-zero flag byte means the packet carries an already known
        # (result, log) pair for this test.
        if ord(stdin.read(1)):
            current_result = (get_cstr(stdin), get_cstr(stdin))
        # The test output: a native-endian 32-bit length, then UTF-8 bytes.
        code_size = struct.unpack("=L", stdin.read(4))[0]
        code = stdin.read(code_size).decode('utf-8')

        results[full_name] = parse_and_check_test(test_name, test_variant, test_source_file, code, current_result)

    result_types = ['passed', 'failed', 'todo', 'empty']
    num_expected = 0
    num_unexpected = 0
    for t in result_types:
        num_expected += print_results(results, t, True)
    for t in result_types:
        num_unexpected += print_results(results, t, False)
    num_expected_skipped = print_results(results, 'skipped', True)
    num_unexpected_skipped = print_results(results, 'skipped', False)

    num_unskipped = len(results) - num_expected_skipped - num_unexpected_skipped
    color = set_red if num_unexpected else set_green
    # With no unskipped tests there is nothing unexpected; report 100%
    # instead of dividing by zero.
    if num_unskipped:
        percent = floor(num_expected / num_unskipped * 100)
    else:
        percent = 100
    print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color, num_expected, percent, num_unskipped, set_normal))
    if num_unexpected_skipped:
        print('%s%d tests had been unexpectedly skipped%s' % (set_red, num_unexpected_skipped, set_normal))

    if num_unexpected:
        sys.exit(1)
531