• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
# Copyright (c) 2020 Valve Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
22import re
23import sys
24import os.path
25import struct
26import string
27import copy
28from math import floor
29
# ANSI escape sequences used when printing the result summary.  They are only
# enabled when stdout is a terminal, so redirected or piped output contains no
# control characters.
#
# sys.stdout.isatty() is used rather than os.isatty(sys.stdout.fileno()):
# fileno() raises for stream objects that have no underlying file descriptor
# (for example io.StringIO), which would make merely importing this module
# fail.
if sys.stdout.isatty():
    set_red = "\033[31m"
    set_green = "\033[1;32m"
    set_normal = "\033[0m"
else:
    set_red = ''
    set_green = ''
    set_normal = ''
38
# Prelude that CheckState schedules as the very first CodeCheck of every test.
# It is executed with CheckState.g as its globals, so `funcs`, `insert_queue`,
# `CodeCheck` and `PatternCheck` are resolved from that dictionary, not from
# this module.
#
# Two details matter here:
#  * PatternCheck takes (data, search, position); insert_pattern supplies a
#    non-searching check with a placeholder position.
#  * The vN lambdas bind the loop variable through a default argument.  A plain
#    closure would look `i` up when called, i.e. after the loop has finished,
#    and every vN function would then describe a 13-register vector.
initial_code = '''
def insert_code(code):
    insert_queue.append(CodeCheck(code))

def insert_pattern(pattern):
    insert_queue.append(PatternCheck(pattern, False, '(code pattern)'))

def vector_gpr(prefix, name, size, align):
    insert_code(f'{name} = {name}0')
    for i in range(size):
        insert_code(f'{name}{i} = {name}0 + {i}')
    insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}')
    insert_code(f'success = {name}0 % {align} == 0')
    return f'{prefix}[#{name}0:#{name}{size - 1}]'

def sgpr_vector(name, size, align):
    return vector_gpr('s', name, size, align)

funcs.update({
    's64': lambda name: vector_gpr('s', name, 2, 2),
    's96': lambda name: vector_gpr('s', name, 3, 2),
    's128': lambda name: vector_gpr('s', name, 4, 4),
    's256': lambda name: vector_gpr('s', name, 8, 4),
    's512': lambda name: vector_gpr('s', name, 16, 4),
})
for i in range(2, 14):
    funcs['v%d' % (i * 32)] = lambda name, size=i: vector_gpr('v', name, size, 1)
'''
67
class Check:
    """Base class for a single check parsed from a test's source comments."""

    def __init__(self, data):
        # Trailing whitespace is dropped; leading whitespace is kept because
        # CodeCheck uses it to work out the indentation of the code.
        self.data = data.rstrip()

    def run(self, state):
        """Run the check against `state`.

        Subclasses return True on success and False on failure.  The base
        implementation does nothing and returns None.
        """
        pass

class CodeCheck(Check):
    """A check whose data is Python code executed in the test's globals."""

    def run(self, state):
        """Dedent and execute the code, then test the `success` global.

        Args:
            state: object with a `g` dictionary (used as the globals of the
                executed code; must contain 'success' and 'log') and a
                `result` object whose `log` string collects messages.

        Returns:
            True if the code ran without raising and `state.g['success']` is
            truthy afterwards, False otherwise.  Failures are described in
            `state.result.log`.
        """
        lines = self.data.split('\n')

        # The first non-blank line defines the indentation shared by all
        # lines.  An empty check (e.g. a lone "//;" comment) has no such line;
        # fall back to "no indentation" instead of raising IndexError.
        first_line = next((l for l in lines if l.strip() != ''), '')
        indent_amount = len(first_line) - len(first_line.lstrip())
        indent = first_line[:indent_amount]

        new_lines = []
        for line in lines:
            if line.strip() == '':
                new_lines.append('')
                continue
            if line[:indent_amount] != indent:
                state.result.log += 'unexpected indent in code check:\n'
                state.result.log += self.data + '\n'
                return False
            new_lines.append(line[indent_amount:])
        code = '\n'.join(new_lines)

        try:
            exec(code, state.g)
            # Whatever the code wrote to the 'log' global is moved into the
            # result log so that it is reported only once.
            state.result.log += state.g['log']
            state.g['log'] = ''
        except BaseException as e:
            # BaseException (not just Exception) is caught, so anything the
            # check code raises is reported as a failed check.
            state.result.log += 'code check raised exception:\n'
            state.result.log += code + '\n'
            state.result.log += str(e)
            return False
        if not state.g['success']:
            state.result.log += 'code check failed:\n'
            state.result.log += code + '\n'
            return False
        return True
107
class StringStream:
    """Read cursor over a string that keeps track of line and column."""

    class Pos:
        """1-based line/column position inside a stream."""

        def __init__(self):
            self.line = 1
            self.column = 1

    def __init__(self, data, name):
        self.name = name  # label identifying the stream in error messages
        self.data = data
        self.offset = 0
        self.pos = StringStream.Pos()

    def reset(self):
        """Move the cursor back to the beginning of the data."""
        self.offset = 0
        self.pos = StringStream.Pos()

    def peek(self, num=1):
        """Return up to `num` characters at the cursor without consuming them.

        Returns '' once the cursor has reached the end of the data.
        """
        return self.data[self.offset:self.offset+num]

    def peek_test(self, chars):
        """Return True if there is a next character and it is in `chars`."""
        c = self.peek(1)
        # The explicit emptiness test matters when `chars` is a string:
        # '' in 'abc' is True in Python.
        return c != '' and c in chars

    def read(self, num=4294967296):
        """Consume and return up to `num` characters, updating the position."""
        res = self.peek(num)
        self.offset += len(res)
        for c in res:
            if c == '\n':
                self.pos.line += 1
                self.pos.column = 1
            else:
                self.pos.column += 1
        return res

    def get_line(self, num):
        """Return line `num` (1-based) of the data without trailing whitespace."""
        return self.data.split('\n')[num - 1].rstrip()

    def skip_line(self):
        """Consume the rest of the current line including its newline."""
        while self.peek(1) not in ['\n', '']:
            self.read(1)
        self.read(1)

    def skip_whitespace(self, inc_line):
        """Consume spaces and tabs, and newlines too if `inc_line` is true."""
        chars = [' ', '\t'] + (['\n'] if inc_line else [])
        while self.peek(1) in chars:
            self.read(1)

    def get_number(self):
        """Consume a run of decimal digits and return it as a string.

        Returns '' if the next character is not a digit.
        """
        num = ''
        # peek_test() is required here: at the end of the data peek() returns
        # '' and `'' in string.digits` is True, which would loop forever.
        while self.peek_test(string.digits):
            num += self.read(1)
        return num

    def check_identifier(self):
        """Return True if an identifier (letter or '_') starts at the cursor."""
        return self.peek_test(string.ascii_letters + '_')

    def get_identifier(self):
        """Consume and return an identifier, or '' if none starts here."""
        res = ''
        if self.check_identifier():
            while self.peek_test(string.ascii_letters + string.digits + '_'):
                res += self.read(1)
        return res
170
def format_error_lines(at, line_num, column_num, ctx, line):
    """Build a two-line error excerpt for `line`.

    The first returned string is `line` in double quotes, preceded by a
    location prefix.  The second is a row of dashes ending in '^', sized so
    the caret sits under column `column_num` (1-based) of the quoted line.
    """
    prefix = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx)
    quoted = prefix + line + '"'
    marker = '-' * (len(prefix) + column_num - 1) + '^'
    return [quoted, marker]
175
class MatchResult:
    """Outcome of one do_match() call and the positions it last recorded."""

    def __init__(self, pattern):
        self.pattern = pattern
        self.success = True
        self.fail_message = ''
        # Result of the nested match (function expansion) that caused this
        # one to fail, if any.
        self.func_res = None
        self.pattern_pos = StringStream.Pos()
        self.output_pos = StringStream.Pos()

    def set_pos(self, pattern, output):
        """Record the current positions of the pattern and output streams."""
        pairs = ((self.pattern_pos, pattern.pos), (self.output_pos, output.pos))
        for target, source in pairs:
            target.line = source.line
            target.column = source.column

    def fail(self, msg):
        """Mark the match as failed with the given message."""
        self.fail_message = msg
        self.success = False

    def format_pattern_pos(self):
        """Describe where in the pattern the match stopped.

        Produces one excerpt for this pattern followed by one excerpt for
        every nested result reachable through `func_res`, joined by newlines.
        """
        where = self.pattern_pos
        report = format_error_lines('at', where.line, where.column, 'pattern',
                                    self.pattern.get_line(where.line))
        nested = self.func_res
        while nested:
            where = nested.pattern_pos
            report += format_error_lines('in', where.line, where.column,
                                         nested.pattern.name,
                                         nested.pattern.get_line(where.line))
            nested = nested.func_res
        return '\n'.join(report)
206
def do_match(g, pattern, output, skip_lines, in_func=False):
    """Match `pattern` against `output`, starting at the output's cursor.

    Pattern syntax, one construct per pattern character:
      \\x        the next character is taken literally
      newline   output must advance to a new line (whitespace is skipped)
      #name     a decimal number; compared with g[name] if that exists,
                otherwise stored there as an int (the name '_' is never
                stored; a '#' without a name just consumes the number)
      $name     a run of non-whitespace characters, compared/stored likewise
      %name     a '%' followed by a decimal number, compared/stored likewise
      @f, @f(a) calls g['funcs'][f] with the raw argument text and matches
                the returned string as a nested pattern
      space     one or more spaces or tabs in the output
      other     the same literal character

    Args:
        g: dictionary of variables; modified as names get bound.
        pattern: StringStream holding the pattern.
        output: StringStream holding the text to check; it is consumed.
        skip_lines: if true, a mismatch restores `g`, drops the rest of the
            current output line and restarts the pattern, until the output
            is exhausted.
        in_func: true for the nested match of a function expansion; leading
            output whitespace is then not skipped and the end-of-line check
            after the pattern is not performed.

    Returns:
        A MatchResult; `success` is False and `fail_message` is set when the
        pattern did not match.
    """
    assert(not in_func or not skip_lines)

    if not in_func:
        output.skip_whitespace(False)
    pattern.skip_whitespace(False)

    # Shallow snapshot used to undo variable bindings made by a failed
    # attempt when searching (skip_lines).
    old_g = copy.copy(g)
    res = MatchResult(pattern)
    escape = False
    while True:
        res.set_pos(pattern, output)

        c = pattern.read(1)
        if c == '':
            break
        elif output.peek() == '':
            res.fail('unexpected end of output')
        elif c == '\\':
            escape = True
            continue
        elif c == '\n':
            old_line = output.pos.line
            output.skip_whitespace(True)
            if output.pos.line == old_line:
                res.fail('expected newline in output')
        elif not escape and c == '#':
            num = output.get_number()
            if num == '':
                res.fail('expected number in output')
            elif pattern.check_identifier():
                name = pattern.get_identifier()
                if name in g and int(num) != g[name]:
                    res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                elif name != '_':
                    g[name] = int(num)
        elif not escape and c == '$':
            name = pattern.get_identifier()

            val = ''
            # Stop at the end of the output as well: peek_test() is False
            # there, so testing it alone would never terminate.
            while output.peek() != '' and not output.peek_test(string.whitespace):
                val += output.read(1)

            if name in g and val != g[name]:
                res.fail('unexpected value for \'%s\': \'%s\' (expected \'%s\')' % (name, val, g[name]))
            elif name != '_':
                g[name] = val
        elif not escape and c == '%' and pattern.check_identifier():
            if output.read(1) != '%':
                res.fail('expected \'%\' in output')
            else:
                num = output.get_number()
                if num == '':
                    res.fail('expected number in output')
                else:
                    name = pattern.get_identifier()
                    if name in g and int(num) != g[name]:
                        res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                    elif name != '_':
                        g[name] = int(num)
        elif not escape and c == '@' and pattern.check_identifier():
            name = pattern.get_identifier()
            args = ''
            if pattern.peek_test('('):
                pattern.read(1)
                while pattern.peek() not in ['', ')']:
                    args += pattern.read(1)
                # Consume the ')' outside the assert: with "python -O" the
                # assert statement, and any read inside it, is removed.
                closing = pattern.read(1)
                assert closing == ')', 'missing closing parenthesis in pattern'
            func_res = g['funcs'][name](args)
            match_res = do_match(g, StringStream(func_res, 'expansion of "%s(%s)"' % (name, args)), output, False, True)
            if not match_res.success:
                res.func_res = match_res
                res.output_pos = match_res.output_pos
                res.fail(match_res.fail_message)
        elif not escape and c == ' ':
            while pattern.peek_test(' '):
                pattern.read(1)

            read_whitespace = False
            while output.peek_test(' \t'):
                output.read(1)
                read_whitespace = True
            if not read_whitespace:
                res.fail('expected whitespace in output, got %r' % (output.peek(1)))
        else:
            outc = output.peek(1)
            if outc != c:
                res.fail('expected %r in output, got %r' % (c, outc))
            else:
                output.read(1)
        if not res.success:
            if skip_lines and output.peek() != '':
                # Searching: undo bindings, then retry the whole pattern on
                # the next output line.
                g.clear()
                g.update(old_g)
                res.success = True
                output.skip_line()
                pattern.reset()
                output.skip_whitespace(False)
                pattern.skip_whitespace(False)
            else:
                return res

        escape = False

    if not in_func:
        # Only trailing blanks may follow the matched text on its line.
        while output.peek() in [' ', '\t']:
            output.read(1)

        if output.read(1) not in ['', '\n']:
            res.fail('expected end of output')
            return res

    return res
322
class PatternCheck(Check):
    """A check that matches a pattern against the output under test."""

    def __init__(self, data, search, position):
        super().__init__(data)
        self.search = search      # passed to do_match() as skip_lines
        self.position = position  # where the check came from, for messages

    def run(self, state):
        """Match the pattern against `state.g['output']`.

        Returns True on a match.  Otherwise appends a description of the
        failure to `state.result.log` and returns False.
        """
        stream = StringStream(self.data.rstrip(), 'pattern')
        res = do_match(state.g, stream, state.g['output'], self.search)
        if res.success:
            return True

        state.result.log += 'pattern at %s failed: %s\n' % (self.position, res.fail_message)
        state.result.log += res.format_pattern_pos() + '\n\n'
        if self.search:
            # A searching pattern has no single place where it failed, so
            # the whole output is shown.
            state.result.log += 'output was:\n'
            state.result.log += state.g['output'].data.rstrip() + '\n'
        else:
            where = res.output_pos
            out_line = state.g['output'].get_line(where.line)
            excerpt = format_error_lines('at', where.line, where.column, 'output', out_line)
            state.result.log += '\n'.join(excerpt)
        return False
343
class CheckState:
    """State shared by all checks while one test's output is verified."""

    def __init__(self, result, variant, checks, output):
        self.result = result
        self.variant = variant

        # The prelude has to run before any check of the test itself.  The
        # caller's list is modified in place.
        checks.insert(0, CodeCheck(initial_code))
        self.checks = checks
        self.insert_queue = []

        # Globals for executed check code and variable store for patterns.
        self.g = dict(
            success=True,
            funcs={},
            insert_queue=self.insert_queue,
            variant=variant,
            log='',
            output=StringStream(output, 'output'),
            CodeCheck=CodeCheck,
            PatternCheck=PatternCheck,
        )
356
class TestResult:
    """Result of one test: expected outcome, actual outcome and a log."""

    def __init__(self, expected):
        self.expected = expected
        # Both start out empty and are filled in once the test was checked.
        self.result = ''
        self.log = ''
362
def check_output(result, variant, checks, output):
    """Run `checks` in order against `output` and record the verdict.

    Sets `result.result` to 'failed' as soon as one check fails, or to
    'passed' once every check (including those queued while running) has
    succeeded.
    """
    state = CheckState(result, variant, checks, output)

    while state.checks:
        current = state.checks.pop(0)
        if not current.run(state):
            result.result = 'failed'
            return

        # Checks queued by the one that just ran go next, in the order in
        # which they were queued.
        state.checks[:0] = state.insert_queue
        state.insert_queue.clear()

    result.result = 'passed'
378
def parse_check(variant, line, checks, pos):
    """Parse one check comment (without its leading '//') into `checks`.

    Recognised prefixes:
      ;       Python code; consecutive code lines are merged into one
              CodeCheck
      !       pattern that must match at the current output position
      >>      pattern that is searched for in the following output lines
      ~regex  the rest of the line (starting at the first ';', '!' or '>>')
              only applies if `regex` matches the whole of `variant`
    Lines with any other prefix are ignored.

    Args:
        variant: name of the test variant being checked.
        line: text of the comment after '//'.
        checks: list the parsed checks are appended to.
        pos: description of the source location, stored in PatternChecks.
    """
    if line.startswith(';'):
        code = line[1:]
        if checks and isinstance(checks[-1], CodeCheck):
            checks[-1].data += '\n' + code
        else:
            checks.append(CodeCheck(code))
    elif line.startswith('!'):
        checks.append(PatternCheck(line[1:], False, pos))
    elif line.startswith('>>'):
        checks.append(PatternCheck(line[2:], True, pos))
    elif line.startswith('~'):
        # The variant regex ends where the first check marker begins.
        marker_positions = [line.find(marker) for marker in (';', '!', '>>')]
        found = [p for p in marker_positions if p != -1]
        if found:
            end = min(found)
            # fullmatch() rather than match() plus an end() comparison:
            # match() settles for the first alternative that matches a
            # prefix, so e.g. 'gfx1|gfx10' would reject the variant 'gfx10'.
            if re.fullmatch(line[1:end], variant):
                parse_check(variant, line[end:], checks, pos)
400
def parse_test_source(test_name, variant, fname):
    """Extract the checks of test `test_name` from the source file `fname`.

    The test body is the text between a line starting with
    BEGIN_TEST(name), BEGIN_TEST_TODO(name) or BEGIN_TEST_FAIL(name) and the
    next line starting with END_TEST.  Every body line that starts with '//'
    (after stripping surrounding whitespace) is handed to parse_check().

    Args:
        test_name: name of the test to look for.
        variant: variant name, forwarded to parse_check().
        fname: path of the source file.

    Returns:
        A tuple (checks, expected_result) where expected_result is 'passed',
        'todo' or 'failed' depending on which BEGIN_TEST form was found
        ('passed' if the test was not found at all).
    """
    in_test = False
    test = []
    expected_result = 'passed'
    # The context manager closes the file even if reading fails.
    with open(fname, 'r') as source:
        for line_num, line in enumerate(source, 1):
            if line.startswith('BEGIN_TEST(%s)' % test_name):
                in_test = True
            elif line.startswith('BEGIN_TEST_TODO(%s)' % test_name):
                in_test = True
                expected_result = 'todo'
            elif line.startswith('BEGIN_TEST_FAIL(%s)' % test_name):
                in_test = True
                expected_result = 'failed'
            elif line.startswith('END_TEST'):
                in_test = False
            elif in_test:
                test.append((line_num, line.strip()))

    checks = []
    base_name = os.path.split(fname)[1]
    for line_num, text in test:
        if text.startswith('//'):
            parse_check(variant, text[2:], checks, 'line %d of %s' % (line_num, base_name))

    return checks, expected_result
426
def parse_and_check_test(test_name, variant, test_file, output, current_result):
    """Parse the checks of one test and determine its result.

    Args:
        test_name: name of the test inside `test_file`.
        variant: variant name of the test.
        test_file: path of the source file containing the test.
        output: text produced by the test, to be verified by the checks.
        current_result: None, or a (result, log) pair that was already
            determined; it is then used as-is and the checks are not run.

    Returns:
        A TestResult.  Its result is 'empty' if the test has no checks.  A
        failing test whose expected result is 'todo' is reported as 'todo'.
    """
    checks, expected = parse_test_source(test_name, variant, test_file)

    result = TestResult(expected)
    if len(checks) == 0:
        result.result = 'empty'
        result.log = 'no checks found'
    elif current_result is not None:
        # Identity comparison is the reliable way to test for None.
        result.result, result.log = current_result
    else:
        check_output(result, variant, checks, output)
        if result.result == 'failed' and expected == 'todo':
            result.result = 'todo'

    return result
442
def print_results(results, output, expected):
    """Print the tests whose result is `output` and return how many there are.

    Args:
        results: mapping of test name to result object (with `result`,
            `expected` and `log` attributes).
        output: the result value to select, e.g. 'passed'.
        expected: if true, select tests whose result equals their expected
            result; if false, select those where it differs.

    Returns:
        The number of selected tests; nothing is printed when it is 0.
    """
    selected = {
        name: res for name, res in results.items()
        if res.result == output and (res.result == res.expected) == expected
    }
    if not selected:
        return 0

    kind = 'expected' if expected else 'unexpected'
    print('%s tests (%s):' % (output, kind))
    for test, res in selected.items():
        # Unexpected results are highlighted.
        color = '' if expected else set_red
        print('   %s%s%s' % (color, test, set_normal))
        if res.log.strip() != '':
            for log_line in res.log.rstrip().split('\n'):
                print('      ' + log_line.rstrip())
    print('')

    return len(selected)
460
def get_cstr(fp):
    """Read a NUL-terminated UTF-8 string from the binary stream `fp`.

    Args:
        fp: object whose read(1) returns one byte, or b'' at end of stream.

    Returns:
        The decoded string, without the terminating NUL byte.

    Raises:
        EOFError: if the stream ends before a NUL byte was read.
    """
    raw = bytearray()
    while True:
        c = fp.read(1)
        if c == b'':
            # Without this check a truncated stream would spin forever,
            # because read() keeps returning b'' at end of stream.
            raise EOFError('unexpected end of stream while reading a NUL-terminated string')
        if c == b'\x00':
            return raw.decode('utf-8')
        raw += c
469
if __name__ == "__main__":
    # Reads test records from stdin until end of input, checks each one and
    # prints a summary.  The exit status is 1 if any test had an unexpected
    # result.
    #
    # Layout of one record, as consumed below:
    #   4 bytes               packet type (its value is not inspected)
    #   NUL-terminated str    test name
    #   NUL-terminated str    test variant (may be empty)
    #   NUL-terminated str    path of the test's source file
    #   1 byte                non-zero if a result is already included
    #   2 NUL-terminated strs result and log (only if the byte was non-zero)
    #   4 bytes               size of the output, unsigned, native byte order
    #   <size> bytes          UTF-8 output to check
    results = {}

    stdin = sys.stdin.buffer
    while True:
        packet_type = stdin.read(4)
        if packet_type == b'':
            break

        test_name = get_cstr(stdin)
        test_variant = get_cstr(stdin)
        if test_variant != '':
            full_name = test_name + '/' + test_variant
        else:
            full_name = test_name

        test_source_file = get_cstr(stdin)
        current_result = None
        if ord(stdin.read(1)):
            current_result = (get_cstr(stdin), get_cstr(stdin))
        code_size = struct.unpack("=L", stdin.read(4))[0]
        code = stdin.read(code_size).decode('utf-8')

        results[full_name] = parse_and_check_test(test_name, test_variant, test_source_file, code, current_result)

    result_types = ['passed', 'failed', 'todo', 'empty']
    num_expected = 0
    num_unexpected = 0
    for t in result_types:
        num_expected += print_results(results, t, True)
    for t in result_types:
        num_unexpected += print_results(results, t, False)
    num_expected_skipped = print_results(results, 'skipped', True)
    num_unexpected_skipped = print_results(results, 'skipped', False)

    num_unskipped = len(results) - num_expected_skipped - num_unexpected_skipped
    color = set_red if num_unexpected else set_green
    # With no unskipped tests there is nothing to divide by; report 100%
    # because no test had an unexpected result.
    if num_unskipped:
        percent = floor(num_expected / num_unskipped * 100)
    else:
        percent = 100
    print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color, num_expected, percent, num_unskipped, set_normal))
    if num_unexpected_skipped:
        print('%s%d tests had been unexpectedly skipped%s' % (set_red, num_unexpected_skipped, set_normal))

    if num_unexpected:
        sys.exit(1)
513