• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2021 The Tint Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// test-runner runs tint against a number of test shaders checking for expected behavior
16package main
17
18import (
19	"context"
20	"flag"
21	"fmt"
22	"io/ioutil"
23	"os"
24	"os/exec"
25	"path/filepath"
26	"regexp"
27	"runtime"
28	"sort"
29	"strings"
30	"time"
31	"unicode/utf8"
32
33	"dawn.googlesource.com/tint/tools/src/fileutils"
34	"dawn.googlesource.com/tint/tools/src/glob"
35	"github.com/fatih/color"
36	"github.com/sergi/go-diff/diffmatchpatch"
37)
38
// outputFormat names one of tint's output formats. The value is passed
// verbatim to tint's --format flag and is also used as the extension of the
// expected output files.
type outputFormat string

const (
	// testTimeout bounds how long a single tint invocation may run.
	testTimeout = 30 * time.Second

	// The output formats the test runner can ask tint to emit.
	glsl   outputFormat = "glsl"
	hlsl   outputFormat = "hlsl"
	msl    outputFormat = "msl"
	spvasm outputFormat = "spvasm"
	wgsl   outputFormat = "wgsl"
)
50
51func main() {
52	if err := run(); err != nil {
53		fmt.Println(err)
54		os.Exit(1)
55	}
56}
57
// showUsage prints the command line usage of test-runner followed by the
// defaults of every registered flag, then exits the process with status 1.
func showUsage() {
	const usage = `
test-runner runs tint against a number of test shaders checking for expected behavior

usage:
  test-runner [flags...] <executable> [<directory>]

  <executable> the path to the tint executable
  <directory>  the root directory of the test files

optional flags:`
	fmt.Println(usage)
	flag.PrintDefaults()
	fmt.Println()
	os.Exit(1)
}
73
74func run() error {
75	var formatList, filter, dxcPath, xcrunPath string
76	var maxFilenameColumnWidth int
77	numCPU := runtime.NumCPU()
78	fxc, verbose, generateExpected, generateSkip := false, false, false, false
79	flag.StringVar(&formatList, "format", "wgsl,spvasm,msl,hlsl", "comma separated list of formats to emit. Possible values are: all, wgsl, spvasm, msl, hlsl, glsl")
80	flag.StringVar(&filter, "filter", "**.wgsl, **.spvasm, **.spv", "comma separated list of glob patterns for test files")
81	flag.StringVar(&dxcPath, "dxc", "", "path to DXC executable for validating HLSL output")
82	flag.StringVar(&xcrunPath, "xcrun", "", "path to xcrun executable for validating MSL output")
83	flag.BoolVar(&fxc, "fxc", false, "validate with FXC instead of DXC")
84	flag.BoolVar(&verbose, "verbose", false, "print all run tests, including rows that all pass")
85	flag.BoolVar(&generateExpected, "generate-expected", false, "create or update all expected outputs")
86	flag.BoolVar(&generateSkip, "generate-skip", false, "create or update all expected outputs that fail with SKIP")
87	flag.IntVar(&numCPU, "j", numCPU, "maximum number of concurrent threads to run tests")
88	flag.IntVar(&maxFilenameColumnWidth, "filename-column-width", 0, "maximum width of the filename column")
89	flag.Usage = showUsage
90	flag.Parse()
91
92	args := flag.Args()
93	if len(args) == 0 {
94		showUsage()
95	}
96
97	// executable path is the first argument
98	exe, args := args[0], args[1:]
99
100	// (optional) target directory is the second argument
101	dir := "."
102	if len(args) > 0 {
103		dir, args = args[0], args[1:]
104	}
105
106	// Check the executable can be found and actually is executable
107	if !fileutils.IsExe(exe) {
108		return fmt.Errorf("'%s' not found or is not executable", exe)
109	}
110	exe, err := filepath.Abs(exe)
111	if err != nil {
112		return err
113	}
114
115	// Allow using '/' in the filter on Windows
116	filter = strings.ReplaceAll(filter, "/", string(filepath.Separator))
117
118	// Split the --filter flag up by ',', trimming any whitespace at the start and end
119	globIncludes := strings.Split(filter, ",")
120	for i, s := range globIncludes {
121		s = filepath.ToSlash(s) // Replace '\' with '/'
122		globIncludes[i] = `"` + strings.TrimSpace(s) + `"`
123	}
124
125	// Glob the files to test
126	files, err := glob.Scan(dir, glob.MustParseConfig(`{
127		"paths": [
128			{
129				"include": [ `+strings.Join(globIncludes, ",")+` ]
130			},
131			{
132				"exclude": [
133					"**.expected.wgsl",
134					"**.expected.spvasm",
135					"**.expected.msl",
136					"**.expected.hlsl"
137				]
138			}
139		]
140	}`))
141	if err != nil {
142		return fmt.Errorf("Failed to glob files: %w", err)
143	}
144
145	// Ensure the files are sorted (globbing should do this, but why not)
146	sort.Strings(files)
147
148	// Parse --format into a list of outputFormat
149	formats := []outputFormat{}
150	if formatList == "all" {
151		formats = []outputFormat{wgsl, spvasm, msl, hlsl, glsl}
152	} else {
153		for _, f := range strings.Split(formatList, ",") {
154			switch strings.TrimSpace(f) {
155			case "wgsl":
156				formats = append(formats, wgsl)
157			case "spvasm":
158				formats = append(formats, spvasm)
159			case "msl":
160				formats = append(formats, msl)
161			case "hlsl":
162				formats = append(formats, hlsl)
163			case "glsl":
164				formats = append(formats, glsl)
165			default:
166				return fmt.Errorf("unknown format '%s'", f)
167			}
168		}
169	}
170
171	defaultMSLExe := "xcrun"
172	if runtime.GOOS == "windows" {
173		defaultMSLExe = "metal.exe"
174	}
175
176	// If explicit verification compilers have been specified, check they exist.
177	// Otherwise, look on PATH for them, but don't error if they cannot be found.
178	for _, tool := range []struct {
179		name string
180		lang string
181		path *string
182	}{
183		{"dxc", "hlsl", &dxcPath},
184		{defaultMSLExe, "msl", &xcrunPath},
185	} {
186		if *tool.path == "" {
187			p, err := exec.LookPath(tool.name)
188			if err == nil && fileutils.IsExe(p) {
189				*tool.path = p
190			}
191		} else if !fileutils.IsExe(*tool.path) {
192			return fmt.Errorf("%v not found at '%v'", tool.name, *tool.path)
193		}
194
195		color.Set(color.FgCyan)
196		fmt.Printf("%-4s", tool.lang)
197		color.Unset()
198		fmt.Printf(" validation ")
199		if *tool.path != "" || (fxc && tool.lang == "hlsl") {
200			color.Set(color.FgGreen)
201			tool_path := *tool.path
202			if fxc && tool.lang == "hlsl" {
203				tool_path = "Tint will use FXC dll in PATH"
204			}
205			fmt.Printf("ENABLED (" + tool_path + ")")
206		} else {
207			color.Set(color.FgRed)
208			fmt.Printf("DISABLED")
209		}
210		color.Unset()
211		fmt.Println()
212	}
213	fmt.Println()
214
215	// Build the list of results.
216	// These hold the chans used to report the job results.
217	results := make([]map[outputFormat]chan status, len(files))
218	for i := range files {
219		fileResults := map[outputFormat]chan status{}
220		for _, format := range formats {
221			fileResults[format] = make(chan status, 1)
222		}
223		results[i] = fileResults
224	}
225
226	pendingJobs := make(chan job, 256)
227
228	// Spawn numCPU job runners...
229	for cpu := 0; cpu < numCPU; cpu++ {
230		go func() {
231			for job := range pendingJobs {
232				job.run(dir, exe, fxc, dxcPath, xcrunPath, generateExpected, generateSkip)
233			}
234		}()
235	}
236
237	// Issue the jobs...
238	go func() {
239		for i, file := range files { // For each test file...
240			file := filepath.Join(dir, file)
241			flags := parseFlags(file)
242			for _, format := range formats { // For each output format...
243				pendingJobs <- job{
244					file:   file,
245					flags:  flags,
246					format: format,
247					result: results[i][format],
248				}
249			}
250		}
251		close(pendingJobs)
252	}()
253
254	type failure struct {
255		file   string
256		format outputFormat
257		err    error
258	}
259
260	type stats struct {
261		numTests, numPass, numSkip, numFail int
262		timeTaken                           time.Duration
263	}
264
265	// Statistics per output format
266	statsByFmt := map[outputFormat]*stats{}
267	for _, format := range formats {
268		statsByFmt[format] = &stats{}
269	}
270
271	// Print the table of file x format and gather per-format stats
272	failures := []failure{}
273	filenameColumnWidth := maxStringLen(files)
274	if maxFilenameColumnWidth > 0 {
275		filenameColumnWidth = maxFilenameColumnWidth
276	}
277
278	red := color.New(color.FgRed)
279	green := color.New(color.FgGreen)
280	yellow := color.New(color.FgYellow)
281	cyan := color.New(color.FgCyan)
282
283	printFormatsHeader := func() {
284		fmt.Printf(strings.Repeat(" ", filenameColumnWidth))
285		fmt.Printf(" ┃ ")
286		for _, format := range formats {
287			cyan.Printf(alignCenter(format, formatWidth(format)))
288			fmt.Printf(" │ ")
289		}
290		fmt.Println()
291	}
292	printHorizontalLine := func() {
293		fmt.Printf(strings.Repeat("━", filenameColumnWidth))
294		fmt.Printf("━╋━")
295		for _, format := range formats {
296			fmt.Printf(strings.Repeat("━", formatWidth(format)))
297			fmt.Printf("━┿━")
298		}
299		fmt.Println()
300	}
301
302	fmt.Println()
303
304	printFormatsHeader()
305	printHorizontalLine()
306
307	for i, file := range files {
308		results := results[i]
309
310		row := &strings.Builder{}
311		rowAllPassed := true
312
313		filenameLength := utf8.RuneCountInString(file)
314		shortFile := file
315		if filenameLength > filenameColumnWidth {
316			shortFile = "..." + file[filenameLength-filenameColumnWidth+3:]
317		}
318
319		fmt.Fprintf(row, alignRight(shortFile, filenameColumnWidth))
320		fmt.Fprintf(row, " ┃ ")
321		for _, format := range formats {
322			columnWidth := formatWidth(format)
323			result := <-results[format]
324			stats := statsByFmt[format]
325			stats.numTests++
326			stats.timeTaken += result.timeTaken
327			if err := result.err; err != nil {
328				failures = append(failures, failure{
329					file: file, format: format, err: err,
330				})
331			}
332			switch result.code {
333			case pass:
334				green.Fprintf(row, alignCenter("PASS", columnWidth))
335				stats.numPass++
336			case fail:
337				red.Fprintf(row, alignCenter("FAIL", columnWidth))
338				rowAllPassed = false
339				stats.numFail++
340			case skip:
341				yellow.Fprintf(row, alignCenter("SKIP", columnWidth))
342				rowAllPassed = false
343				stats.numSkip++
344			default:
345				fmt.Fprintf(row, alignCenter(result.code, columnWidth))
346				rowAllPassed = false
347			}
348			fmt.Fprintf(row, " │ ")
349		}
350
351		if verbose || !rowAllPassed {
352			fmt.Fprintln(color.Output, row)
353		}
354	}
355
356	printHorizontalLine()
357	printFormatsHeader()
358	printHorizontalLine()
359	printStat := func(col *color.Color, name string, num func(*stats) int) {
360		row := &strings.Builder{}
361		anyNonZero := false
362		for _, format := range formats {
363			columnWidth := formatWidth(format)
364			count := num(statsByFmt[format])
365			if count > 0 {
366				col.Fprintf(row, alignLeft(count, columnWidth))
367				anyNonZero = true
368			} else {
369				fmt.Fprintf(row, alignLeft(count, columnWidth))
370			}
371			fmt.Fprintf(row, " │ ")
372		}
373
374		if !anyNonZero {
375			return
376		}
377		col.Printf(alignRight(name, filenameColumnWidth))
378		fmt.Printf(" ┃ ")
379		fmt.Fprintln(color.Output, row)
380
381		col.Printf(strings.Repeat(" ", filenameColumnWidth))
382		fmt.Printf(" ┃ ")
383		for _, format := range formats {
384			columnWidth := formatWidth(format)
385			stats := statsByFmt[format]
386			count := num(stats)
387			percent := percentage(count, stats.numTests)
388			if count > 0 {
389				col.Print(alignRight(percent, columnWidth))
390			} else {
391				fmt.Print(alignRight(percent, columnWidth))
392			}
393			fmt.Printf(" │ ")
394		}
395		fmt.Println()
396	}
397	printStat(green, "PASS", func(s *stats) int { return s.numPass })
398	printStat(yellow, "SKIP", func(s *stats) int { return s.numSkip })
399	printStat(red, "FAIL", func(s *stats) int { return s.numFail })
400
401	cyan.Printf(alignRight("TIME", filenameColumnWidth))
402	fmt.Printf(" ┃ ")
403	for _, format := range formats {
404		timeTaken := printDuration(statsByFmt[format].timeTaken)
405		cyan.Printf(alignLeft(timeTaken, formatWidth(format)))
406		fmt.Printf(" │ ")
407	}
408	fmt.Println()
409
410	for _, f := range failures {
411		color.Set(color.FgBlue)
412		fmt.Printf("%s ", f.file)
413		color.Set(color.FgCyan)
414		fmt.Printf("%s ", f.format)
415		color.Set(color.FgRed)
416		fmt.Println("FAIL")
417		color.Unset()
418		fmt.Println(indent(f.err.Error(), 4))
419	}
420	if len(failures) > 0 {
421		fmt.Println()
422	}
423
424	allStats := stats{}
425	for _, format := range formats {
426		stats := statsByFmt[format]
427		allStats.numTests += stats.numTests
428		allStats.numPass += stats.numPass
429		allStats.numSkip += stats.numSkip
430		allStats.numFail += stats.numFail
431	}
432
433	fmt.Printf("%d tests run", allStats.numTests)
434	if allStats.numPass > 0 {
435		fmt.Printf(", ")
436		color.Set(color.FgGreen)
437		fmt.Printf("%d tests pass", allStats.numPass)
438		color.Unset()
439	} else {
440		fmt.Printf(", %d tests pass", allStats.numPass)
441	}
442	if allStats.numSkip > 0 {
443		fmt.Printf(", ")
444		color.Set(color.FgYellow)
445		fmt.Printf("%d tests skipped", allStats.numSkip)
446		color.Unset()
447	} else {
448		fmt.Printf(", %d tests skipped", allStats.numSkip)
449	}
450	if allStats.numFail > 0 {
451		fmt.Printf(", ")
452		color.Set(color.FgRed)
453		fmt.Printf("%d tests failed", allStats.numFail)
454		color.Unset()
455	} else {
456		fmt.Printf(", %d tests failed", allStats.numFail)
457	}
458	fmt.Println()
459	fmt.Println()
460
461	if allStats.numFail > 0 {
462		os.Exit(1)
463	}
464
465	return nil
466}
467
468// Structures to hold the results of the tests
469type statusCode string
470
471const (
472	fail statusCode = "FAIL"
473	pass statusCode = "PASS"
474	skip statusCode = "SKIP"
475)
476
477type status struct {
478	code      statusCode
479	err       error
480	timeTaken time.Duration
481}
482
483type job struct {
484	file   string
485	flags  []string
486	format outputFormat
487	result chan status
488}
489
490func (j job) run(wd, exe string, fxc bool, dxcPath, xcrunPath string, generateExpected, generateSkip bool) {
491	j.result <- func() status {
492		// Is there an expected output?
493		expected := loadExpectedFile(j.file, j.format)
494		skipped := false
495		if strings.HasPrefix(expected, "SKIP") { // Special SKIP token
496			skipped = true
497		}
498
499		expected = strings.ReplaceAll(expected, "\r\n", "\n")
500
501		file, err := filepath.Rel(wd, j.file)
502		if err != nil {
503			file = j.file
504		}
505
506		// Make relative paths use forward slash separators (on Windows) so that paths in tint
507		// output match expected output that contain errors
508		file = strings.ReplaceAll(file, `\`, `/`)
509
510		args := []string{
511			file,
512			"--format", string(j.format),
513		}
514
515		// Can we validate?
516		validate := false
517		switch j.format {
518		case wgsl:
519			validate = true
520		case spvasm, glsl:
521			args = append(args, "--validate") // spirv-val and glslang are statically linked, always available
522			validate = true
523		case hlsl:
524			if fxc {
525				args = append(args, "--fxc")
526				validate = true
527			} else if dxcPath != "" {
528				args = append(args, "--dxc", dxcPath)
529				validate = true
530			}
531		case msl:
532			if xcrunPath != "" {
533				args = append(args, "--xcrun", xcrunPath)
534				validate = true
535			}
536		}
537
538		args = append(args, j.flags...)
539
540		// Invoke the compiler...
541		start := time.Now()
542		ok, out := invoke(wd, exe, args...)
543		timeTaken := time.Since(start)
544
545		out = strings.ReplaceAll(out, "\r\n", "\n")
546		matched := expected == "" || expected == out
547
548		if ok && generateExpected && (validate || !skipped) {
549			saveExpectedFile(j.file, j.format, out)
550			matched = true
551		}
552
553		switch {
554		case ok && matched:
555			// Test passed
556			return status{code: pass, timeTaken: timeTaken}
557
558			//       --- Below this point the test has failed ---
559
560		case skipped:
561			if generateSkip {
562				saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out)
563			}
564			return status{code: skip, timeTaken: timeTaken}
565
566		case !ok:
567			// Compiler returned non-zero exit code
568			if generateSkip {
569				saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out)
570			}
571			err := fmt.Errorf("%s", out)
572			return status{code: fail, err: err, timeTaken: timeTaken}
573
574		default:
575			// Compiler returned zero exit code, or output was not as expected
576			if generateSkip {
577				saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out)
578			}
579
580			// Expected output did not match
581			dmp := diffmatchpatch.New()
582			diff := dmp.DiffPrettyText(dmp.DiffMain(expected, out, true))
583			err := fmt.Errorf(`Output was not as expected
584
585--------------------------------------------------------------------------------
586-- Expected:                                                                  --
587--------------------------------------------------------------------------------
588%s
589
590--------------------------------------------------------------------------------
591-- Got:                                                                       --
592--------------------------------------------------------------------------------
593%s
594
595--------------------------------------------------------------------------------
596-- Diff:                                                                      --
597--------------------------------------------------------------------------------
598%s`,
599				expected, out, diff)
600			return status{code: fail, err: err, timeTaken: timeTaken}
601		}
602	}()
603}
604
605// loadExpectedFile loads the expected output file for the test file at 'path'
606// and the output format 'format'. If the file does not exist, or cannot be
607// read, then an empty string is returned.
608func loadExpectedFile(path string, format outputFormat) string {
609	content, err := ioutil.ReadFile(expectedFilePath(path, format))
610	if err != nil {
611		return ""
612	}
613	return string(content)
614}
615
616// saveExpectedFile writes the expected output file for the test file at 'path'
617// and the output format 'format', with the content 'content'.
618func saveExpectedFile(path string, format outputFormat, content string) error {
619	// Don't generate expected results for certain directories that contain
620	// large corpora of tests for which the generated code is uninteresting.
621	for _, exclude := range []string{"/test/unittest/", "/test/vk-gl-cts/"} {
622		if strings.Contains(path, filepath.FromSlash(exclude)) {
623			return nil
624		}
625	}
626	return ioutil.WriteFile(expectedFilePath(path, format), []byte(content), 0666)
627}
628
629// expectedFilePath returns the expected output file path for the test file at
630// 'path' and the output format 'format'.
631func expectedFilePath(path string, format outputFormat) string {
632	return path + ".expected." + string(format)
633}
634
// indent prefixes every line of 's' (including an empty trailing line) with
// 'n' spaces.
func indent(s string, n int) string {
	pad := strings.Repeat(" ", n)
	lines := strings.Split(s, "\n")
	for i := range lines {
		lines[i] = pad + lines[i]
	}
	return strings.Join(lines, "\n")
}
640
// alignLeft formats 'val' with fmt.Sprint and pads it on the right with
// spaces to 'width' runes. Values already wider than 'width' are returned
// unpadded.
func alignLeft(val interface{}, width int) string {
	str := fmt.Sprint(val)
	if gap := width - utf8.RuneCountInString(str); gap > 0 {
		return str + strings.Repeat(" ", gap)
	}
	return str
}
651
// alignCenter formats 'val' with fmt.Sprint and pads it with spaces on both
// sides to 'width' runes; when the padding is odd, the extra space goes on the
// right. Values already wider than 'width' are returned unpadded.
func alignCenter(val interface{}, width int) string {
	str := fmt.Sprint(val)
	gap := width - utf8.RuneCountInString(str)
	if gap <= 0 {
		return str
	}
	left := gap / 2
	right := gap - left
	return strings.Repeat(" ", left) + str + strings.Repeat(" ", right)
}
662
// alignRight formats 'val' with fmt.Sprint and pads it on the left with
// spaces to 'width' runes. Values already wider than 'width' are returned
// unpadded.
func alignRight(val interface{}, width int) string {
	str := fmt.Sprint(val)
	if gap := width - utf8.RuneCountInString(str); gap > 0 {
		return strings.Repeat(" ", gap) + str
	}
	return str
}
673
// maxStringLen returns the length, in runes, of the longest string in 'l', or
// 0 if 'l' is empty.
func maxStringLen(l []string) int {
	longest := 0
	for i := range l {
		if n := utf8.RuneCountInString(l[i]); n > longest {
			longest = n
		}
	}
	return longest
}
685
686// formatWidth returns the width in runes for the outputFormat column 'b'
687func formatWidth(b outputFormat) int {
688	const min = 6
689	c := utf8.RuneCountInString(string(b))
690	if c < min {
691		return min
692	}
693	return c
694}
695
// percentage formats n as a percentage of total with one decimal place
// (e.g. "33.3%"). It returns "-" when total is zero.
func percentage(n, total int) string {
	if total != 0 {
		ratio := float64(n) / float64(total)
		return fmt.Sprintf("%.1f%%", ratio*100.0)
	}
	return "-"
}
704
705// invoke runs the executable 'exe' with the provided arguments.
706func invoke(wd, exe string, args ...string) (ok bool, output string) {
707	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
708	defer cancel()
709
710	cmd := exec.CommandContext(ctx, exe, args...)
711	cmd.Dir = wd
712	out, err := cmd.CombinedOutput()
713	str := string(out)
714	if err != nil {
715		if ctx.Err() == context.DeadlineExceeded {
716			return false, fmt.Sprintf("test timed out after %v", testTimeout)
717		}
718		if str != "" {
719			return false, str
720		}
721		return false, err.Error()
722	}
723	return true, str
724}
725
// reFlags matches a single line holding a `// flags:` comment, capturing
// everything after the colon.
var reFlags = regexp.MustCompile(`^\s*//\s*flags:(.*)$`)

// parseFlags looks for a `// flags:` line in the comment header at the start
// of the file with the given path, returning the whitespace-delimited tokens
// that follow it. The header is the run of leading lines that are blank or
// `//` comments; scanning stops at the first other line, so `// flags:`
// comments further down the file are ignored. nil is returned if the file
// cannot be read or its header contains no flags line.
func parseFlags(path string) []string {
	content, err := ioutil.ReadFile(path)
	if err != nil {
		return nil
	}
	for _, line := range strings.Split(string(content), "\n") {
		// Tolerate CRLF line endings.
		line = strings.TrimSuffix(line, "\r")
		if m := reFlags.FindStringSubmatch(line); m != nil {
			// Fields, unlike splitting on " ", never yields empty tokens for
			// leading, trailing or repeated whitespace.
			return strings.Fields(m[1])
		}
		if trimmed := strings.TrimSpace(line); trimmed != "" && !strings.HasPrefix(trimmed, "//") {
			break // end of the comment header
		}
	}
	return nil
}
743
// printDuration formats d as a compact string such as "1h2m5s", truncated to
// whole seconds. Zero hour and minute components are omitted; a duration
// shorter than one second is rendered as "0s" rather than an empty string.
func printDuration(d time.Duration) string {
	total := int(d.Seconds())
	hours := total / 3600
	mins := (total / 60) % 60
	secs := total % 60

	sb := &strings.Builder{}
	if hours > 0 {
		fmt.Fprintf(sb, "%dh", hours)
	}
	if mins > 0 {
		fmt.Fprintf(sb, "%dm", mins)
	}
	// Emit the seconds if they are non-zero, or if nothing else was written,
	// so the result is never empty.
	if secs > 0 || sb.Len() == 0 {
		fmt.Fprintf(sb, "%ds", secs)
	}
	return sb.String()
}
762