1// Copyright 2021 The Tint Authors. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// test-runner runs tint against a number of test shaders checking for expected behavior 16package main 17 18import ( 19 "context" 20 "flag" 21 "fmt" 22 "io/ioutil" 23 "os" 24 "os/exec" 25 "path/filepath" 26 "regexp" 27 "runtime" 28 "sort" 29 "strings" 30 "time" 31 "unicode/utf8" 32 33 "dawn.googlesource.com/tint/tools/src/fileutils" 34 "dawn.googlesource.com/tint/tools/src/glob" 35 "github.com/fatih/color" 36 "github.com/sergi/go-diff/diffmatchpatch" 37) 38 39type outputFormat string 40 41const ( 42 testTimeout = 30 * time.Second 43 44 glsl = outputFormat("glsl") 45 hlsl = outputFormat("hlsl") 46 msl = outputFormat("msl") 47 spvasm = outputFormat("spvasm") 48 wgsl = outputFormat("wgsl") 49) 50 51func main() { 52 if err := run(); err != nil { 53 fmt.Println(err) 54 os.Exit(1) 55 } 56} 57 58func showUsage() { 59 fmt.Println(` 60test-runner runs tint against a number of test shaders checking for expected behavior 61 62usage: 63 test-runner [flags...] <executable> [<directory>] 64 65 <executable> the path to the tint executable 66 <directory> the root directory of the test files 67 68optional flags:`) 69 flag.PrintDefaults() 70 fmt.Println(``) 71 os.Exit(1) 72} 73 74func run() error { 75 var formatList, filter, dxcPath, xcrunPath string 76 var maxFilenameColumnWidth int 77 numCPU := runtime.NumCPU() 78 fxc, verbose, generateExpected, generateSkip := false, false, false, false 79 flag.StringVar(&formatList, "format", "wgsl,spvasm,msl,hlsl", "comma separated list of formats to emit. Possible values are: all, wgsl, spvasm, msl, hlsl, glsl") 80 flag.StringVar(&filter, "filter", "**.wgsl, **.spvasm, **.spv", "comma separated list of glob patterns for test files") 81 flag.StringVar(&dxcPath, "dxc", "", "path to DXC executable for validating HLSL output") 82 flag.StringVar(&xcrunPath, "xcrun", "", "path to xcrun executable for validating MSL output") 83 flag.BoolVar(&fxc, "fxc", false, "validate with FXC instead of DXC") 84 flag.BoolVar(&verbose, "verbose", false, "print all run tests, including rows that all pass") 85 flag.BoolVar(&generateExpected, "generate-expected", false, "create or update all expected outputs") 86 flag.BoolVar(&generateSkip, "generate-skip", false, "create or update all expected outputs that fail with SKIP") 87 flag.IntVar(&numCPU, "j", numCPU, "maximum number of concurrent threads to run tests") 88 flag.IntVar(&maxFilenameColumnWidth, "filename-column-width", 0, "maximum width of the filename column") 89 flag.Usage = showUsage 90 flag.Parse() 91 92 args := flag.Args() 93 if len(args) == 0 { 94 showUsage() 95 } 96 97 // executable path is the first argument 98 exe, args := args[0], args[1:] 99 100 // (optional) target directory is the second argument 101 dir := "." 102 if len(args) > 0 { 103 dir, args = args[0], args[1:] 104 } 105 106 // Check the executable can be found and actually is executable 107 if !fileutils.IsExe(exe) { 108 return fmt.Errorf("'%s' not found or is not executable", exe) 109 } 110 exe, err := filepath.Abs(exe) 111 if err != nil { 112 return err 113 } 114 115 // Allow using '/' in the filter on Windows 116 filter = strings.ReplaceAll(filter, "/", string(filepath.Separator)) 117 118 // Split the --filter flag up by ',', trimming any whitespace at the start and end 119 globIncludes := strings.Split(filter, ",") 120 for i, s := range globIncludes { 121 s = filepath.ToSlash(s) // Replace '\' with '/' 122 globIncludes[i] = `"` + strings.TrimSpace(s) + `"` 123 } 124 125 // Glob the files to test 126 files, err := glob.Scan(dir, glob.MustParseConfig(`{ 127 "paths": [ 128 { 129 "include": [ `+strings.Join(globIncludes, ",")+` ] 130 }, 131 { 132 "exclude": [ 133 "**.expected.wgsl", 134 "**.expected.spvasm", 135 "**.expected.msl", 136 "**.expected.hlsl" 137 ] 138 } 139 ] 140 }`)) 141 if err != nil { 142 return fmt.Errorf("Failed to glob files: %w", err) 143 } 144 145 // Ensure the files are sorted (globbing should do this, but why not) 146 sort.Strings(files) 147 148 // Parse --format into a list of outputFormat 149 formats := []outputFormat{} 150 if formatList == "all" { 151 formats = []outputFormat{wgsl, spvasm, msl, hlsl, glsl} 152 } else { 153 for _, f := range strings.Split(formatList, ",") { 154 switch strings.TrimSpace(f) { 155 case "wgsl": 156 formats = append(formats, wgsl) 157 case "spvasm": 158 formats = append(formats, spvasm) 159 case "msl": 160 formats = append(formats, msl) 161 case "hlsl": 162 formats = append(formats, hlsl) 163 case "glsl": 164 formats = append(formats, glsl) 165 default: 166 return fmt.Errorf("unknown format '%s'", f) 167 } 168 } 169 } 170 171 defaultMSLExe := "xcrun" 172 if runtime.GOOS == "windows" { 173 defaultMSLExe = "metal.exe" 174 } 175 176 // If explicit verification compilers have been specified, check they exist. 177 // Otherwise, look on PATH for them, but don't error if they cannot be found. 178 for _, tool := range []struct { 179 name string 180 lang string 181 path *string 182 }{ 183 {"dxc", "hlsl", &dxcPath}, 184 {defaultMSLExe, "msl", &xcrunPath}, 185 } { 186 if *tool.path == "" { 187 p, err := exec.LookPath(tool.name) 188 if err == nil && fileutils.IsExe(p) { 189 *tool.path = p 190 } 191 } else if !fileutils.IsExe(*tool.path) { 192 return fmt.Errorf("%v not found at '%v'", tool.name, *tool.path) 193 } 194 195 color.Set(color.FgCyan) 196 fmt.Printf("%-4s", tool.lang) 197 color.Unset() 198 fmt.Printf(" validation ") 199 if *tool.path != "" || (fxc && tool.lang == "hlsl") { 200 color.Set(color.FgGreen) 201 tool_path := *tool.path 202 if fxc && tool.lang == "hlsl" { 203 tool_path = "Tint will use FXC dll in PATH" 204 } 205 fmt.Printf("ENABLED (" + tool_path + ")") 206 } else { 207 color.Set(color.FgRed) 208 fmt.Printf("DISABLED") 209 } 210 color.Unset() 211 fmt.Println() 212 } 213 fmt.Println() 214 215 // Build the list of results. 216 // These hold the chans used to report the job results. 217 results := make([]map[outputFormat]chan status, len(files)) 218 for i := range files { 219 fileResults := map[outputFormat]chan status{} 220 for _, format := range formats { 221 fileResults[format] = make(chan status, 1) 222 } 223 results[i] = fileResults 224 } 225 226 pendingJobs := make(chan job, 256) 227 228 // Spawn numCPU job runners... 229 for cpu := 0; cpu < numCPU; cpu++ { 230 go func() { 231 for job := range pendingJobs { 232 job.run(dir, exe, fxc, dxcPath, xcrunPath, generateExpected, generateSkip) 233 } 234 }() 235 } 236 237 // Issue the jobs... 238 go func() { 239 for i, file := range files { // For each test file... 240 file := filepath.Join(dir, file) 241 flags := parseFlags(file) 242 for _, format := range formats { // For each output format... 243 pendingJobs <- job{ 244 file: file, 245 flags: flags, 246 format: format, 247 result: results[i][format], 248 } 249 } 250 } 251 close(pendingJobs) 252 }() 253 254 type failure struct { 255 file string 256 format outputFormat 257 err error 258 } 259 260 type stats struct { 261 numTests, numPass, numSkip, numFail int 262 timeTaken time.Duration 263 } 264 265 // Statistics per output format 266 statsByFmt := map[outputFormat]*stats{} 267 for _, format := range formats { 268 statsByFmt[format] = &stats{} 269 } 270 271 // Print the table of file x format and gather per-format stats 272 failures := []failure{} 273 filenameColumnWidth := maxStringLen(files) 274 if maxFilenameColumnWidth > 0 { 275 filenameColumnWidth = maxFilenameColumnWidth 276 } 277 278 red := color.New(color.FgRed) 279 green := color.New(color.FgGreen) 280 yellow := color.New(color.FgYellow) 281 cyan := color.New(color.FgCyan) 282 283 printFormatsHeader := func() { 284 fmt.Printf(strings.Repeat(" ", filenameColumnWidth)) 285 fmt.Printf(" ┃ ") 286 for _, format := range formats { 287 cyan.Printf(alignCenter(format, formatWidth(format))) 288 fmt.Printf(" │ ") 289 } 290 fmt.Println() 291 } 292 printHorizontalLine := func() { 293 fmt.Printf(strings.Repeat("━", filenameColumnWidth)) 294 fmt.Printf("━╋━") 295 for _, format := range formats { 296 fmt.Printf(strings.Repeat("━", formatWidth(format))) 297 fmt.Printf("━┿━") 298 } 299 fmt.Println() 300 } 301 302 fmt.Println() 303 304 printFormatsHeader() 305 printHorizontalLine() 306 307 for i, file := range files { 308 results := results[i] 309 310 row := &strings.Builder{} 311 rowAllPassed := true 312 313 filenameLength := utf8.RuneCountInString(file) 314 shortFile := file 315 if filenameLength > filenameColumnWidth { 316 shortFile = "..." + file[filenameLength-filenameColumnWidth+3:] 317 } 318 319 fmt.Fprintf(row, alignRight(shortFile, filenameColumnWidth)) 320 fmt.Fprintf(row, " ┃ ") 321 for _, format := range formats { 322 columnWidth := formatWidth(format) 323 result := <-results[format] 324 stats := statsByFmt[format] 325 stats.numTests++ 326 stats.timeTaken += result.timeTaken 327 if err := result.err; err != nil { 328 failures = append(failures, failure{ 329 file: file, format: format, err: err, 330 }) 331 } 332 switch result.code { 333 case pass: 334 green.Fprintf(row, alignCenter("PASS", columnWidth)) 335 stats.numPass++ 336 case fail: 337 red.Fprintf(row, alignCenter("FAIL", columnWidth)) 338 rowAllPassed = false 339 stats.numFail++ 340 case skip: 341 yellow.Fprintf(row, alignCenter("SKIP", columnWidth)) 342 rowAllPassed = false 343 stats.numSkip++ 344 default: 345 fmt.Fprintf(row, alignCenter(result.code, columnWidth)) 346 rowAllPassed = false 347 } 348 fmt.Fprintf(row, " │ ") 349 } 350 351 if verbose || !rowAllPassed { 352 fmt.Fprintln(color.Output, row) 353 } 354 } 355 356 printHorizontalLine() 357 printFormatsHeader() 358 printHorizontalLine() 359 printStat := func(col *color.Color, name string, num func(*stats) int) { 360 row := &strings.Builder{} 361 anyNonZero := false 362 for _, format := range formats { 363 columnWidth := formatWidth(format) 364 count := num(statsByFmt[format]) 365 if count > 0 { 366 col.Fprintf(row, alignLeft(count, columnWidth)) 367 anyNonZero = true 368 } else { 369 fmt.Fprintf(row, alignLeft(count, columnWidth)) 370 } 371 fmt.Fprintf(row, " │ ") 372 } 373 374 if !anyNonZero { 375 return 376 } 377 col.Printf(alignRight(name, filenameColumnWidth)) 378 fmt.Printf(" ┃ ") 379 fmt.Fprintln(color.Output, row) 380 381 col.Printf(strings.Repeat(" ", filenameColumnWidth)) 382 fmt.Printf(" ┃ ") 383 for _, format := range formats { 384 columnWidth := formatWidth(format) 385 stats := statsByFmt[format] 386 count := num(stats) 387 percent := percentage(count, stats.numTests) 388 if count > 0 { 389 col.Print(alignRight(percent, columnWidth)) 390 } else { 391 fmt.Print(alignRight(percent, columnWidth)) 392 } 393 fmt.Printf(" │ ") 394 } 395 fmt.Println() 396 } 397 printStat(green, "PASS", func(s *stats) int { return s.numPass }) 398 printStat(yellow, "SKIP", func(s *stats) int { return s.numSkip }) 399 printStat(red, "FAIL", func(s *stats) int { return s.numFail }) 400 401 cyan.Printf(alignRight("TIME", filenameColumnWidth)) 402 fmt.Printf(" ┃ ") 403 for _, format := range formats { 404 timeTaken := printDuration(statsByFmt[format].timeTaken) 405 cyan.Printf(alignLeft(timeTaken, formatWidth(format))) 406 fmt.Printf(" │ ") 407 } 408 fmt.Println() 409 410 for _, f := range failures { 411 color.Set(color.FgBlue) 412 fmt.Printf("%s ", f.file) 413 color.Set(color.FgCyan) 414 fmt.Printf("%s ", f.format) 415 color.Set(color.FgRed) 416 fmt.Println("FAIL") 417 color.Unset() 418 fmt.Println(indent(f.err.Error(), 4)) 419 } 420 if len(failures) > 0 { 421 fmt.Println() 422 } 423 424 allStats := stats{} 425 for _, format := range formats { 426 stats := statsByFmt[format] 427 allStats.numTests += stats.numTests 428 allStats.numPass += stats.numPass 429 allStats.numSkip += stats.numSkip 430 allStats.numFail += stats.numFail 431 } 432 433 fmt.Printf("%d tests run", allStats.numTests) 434 if allStats.numPass > 0 { 435 fmt.Printf(", ") 436 color.Set(color.FgGreen) 437 fmt.Printf("%d tests pass", allStats.numPass) 438 color.Unset() 439 } else { 440 fmt.Printf(", %d tests pass", allStats.numPass) 441 } 442 if allStats.numSkip > 0 { 443 fmt.Printf(", ") 444 color.Set(color.FgYellow) 445 fmt.Printf("%d tests skipped", allStats.numSkip) 446 color.Unset() 447 } else { 448 fmt.Printf(", %d tests skipped", allStats.numSkip) 449 } 450 if allStats.numFail > 0 { 451 fmt.Printf(", ") 452 color.Set(color.FgRed) 453 fmt.Printf("%d tests failed", allStats.numFail) 454 color.Unset() 455 } else { 456 fmt.Printf(", %d tests failed", allStats.numFail) 457 } 458 fmt.Println() 459 fmt.Println() 460 461 if allStats.numFail > 0 { 462 os.Exit(1) 463 } 464 465 return nil 466} 467 468// Structures to hold the results of the tests 469type statusCode string 470 471const ( 472 fail statusCode = "FAIL" 473 pass statusCode = "PASS" 474 skip statusCode = "SKIP" 475) 476 477type status struct { 478 code statusCode 479 err error 480 timeTaken time.Duration 481} 482 483type job struct { 484 file string 485 flags []string 486 format outputFormat 487 result chan status 488} 489 490func (j job) run(wd, exe string, fxc bool, dxcPath, xcrunPath string, generateExpected, generateSkip bool) { 491 j.result <- func() status { 492 // Is there an expected output? 493 expected := loadExpectedFile(j.file, j.format) 494 skipped := false 495 if strings.HasPrefix(expected, "SKIP") { // Special SKIP token 496 skipped = true 497 } 498 499 expected = strings.ReplaceAll(expected, "\r\n", "\n") 500 501 file, err := filepath.Rel(wd, j.file) 502 if err != nil { 503 file = j.file 504 } 505 506 // Make relative paths use forward slash separators (on Windows) so that paths in tint 507 // output match expected output that contain errors 508 file = strings.ReplaceAll(file, `\`, `/`) 509 510 args := []string{ 511 file, 512 "--format", string(j.format), 513 } 514 515 // Can we validate? 516 validate := false 517 switch j.format { 518 case wgsl: 519 validate = true 520 case spvasm, glsl: 521 args = append(args, "--validate") // spirv-val and glslang are statically linked, always available 522 validate = true 523 case hlsl: 524 if fxc { 525 args = append(args, "--fxc") 526 validate = true 527 } else if dxcPath != "" { 528 args = append(args, "--dxc", dxcPath) 529 validate = true 530 } 531 case msl: 532 if xcrunPath != "" { 533 args = append(args, "--xcrun", xcrunPath) 534 validate = true 535 } 536 } 537 538 args = append(args, j.flags...) 539 540 // Invoke the compiler... 541 start := time.Now() 542 ok, out := invoke(wd, exe, args...) 543 timeTaken := time.Since(start) 544 545 out = strings.ReplaceAll(out, "\r\n", "\n") 546 matched := expected == "" || expected == out 547 548 if ok && generateExpected && (validate || !skipped) { 549 saveExpectedFile(j.file, j.format, out) 550 matched = true 551 } 552 553 switch { 554 case ok && matched: 555 // Test passed 556 return status{code: pass, timeTaken: timeTaken} 557 558 // --- Below this point the test has failed --- 559 560 case skipped: 561 if generateSkip { 562 saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out) 563 } 564 return status{code: skip, timeTaken: timeTaken} 565 566 case !ok: 567 // Compiler returned non-zero exit code 568 if generateSkip { 569 saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out) 570 } 571 err := fmt.Errorf("%s", out) 572 return status{code: fail, err: err, timeTaken: timeTaken} 573 574 default: 575 // Compiler returned zero exit code, or output was not as expected 576 if generateSkip { 577 saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out) 578 } 579 580 // Expected output did not match 581 dmp := diffmatchpatch.New() 582 diff := dmp.DiffPrettyText(dmp.DiffMain(expected, out, true)) 583 err := fmt.Errorf(`Output was not as expected 584 585-------------------------------------------------------------------------------- 586-- Expected: -- 587-------------------------------------------------------------------------------- 588%s 589 590-------------------------------------------------------------------------------- 591-- Got: -- 592-------------------------------------------------------------------------------- 593%s 594 595-------------------------------------------------------------------------------- 596-- Diff: -- 597-------------------------------------------------------------------------------- 598%s`, 599 expected, out, diff) 600 return status{code: fail, err: err, timeTaken: timeTaken} 601 } 602 }() 603} 604 605// loadExpectedFile loads the expected output file for the test file at 'path' 606// and the output format 'format'. If the file does not exist, or cannot be 607// read, then an empty string is returned. 608func loadExpectedFile(path string, format outputFormat) string { 609 content, err := ioutil.ReadFile(expectedFilePath(path, format)) 610 if err != nil { 611 return "" 612 } 613 return string(content) 614} 615 616// saveExpectedFile writes the expected output file for the test file at 'path' 617// and the output format 'format', with the content 'content'. 618func saveExpectedFile(path string, format outputFormat, content string) error { 619 // Don't generate expected results for certain directories that contain 620 // large corpora of tests for which the generated code is uninteresting. 621 for _, exclude := range []string{"/test/unittest/", "/test/vk-gl-cts/"} { 622 if strings.Contains(path, filepath.FromSlash(exclude)) { 623 return nil 624 } 625 } 626 return ioutil.WriteFile(expectedFilePath(path, format), []byte(content), 0666) 627} 628 629// expectedFilePath returns the expected output file path for the test file at 630// 'path' and the output format 'format'. 631func expectedFilePath(path string, format outputFormat) string { 632 return path + ".expected." + string(format) 633} 634 635// indent returns the string 's' indented with 'n' whitespace characters 636func indent(s string, n int) string { 637 tab := strings.Repeat(" ", n) 638 return tab + strings.ReplaceAll(s, "\n", "\n"+tab) 639} 640 641// alignLeft returns the string of 'val' padded so that it is aligned left in 642// a column of the given width 643func alignLeft(val interface{}, width int) string { 644 s := fmt.Sprint(val) 645 padding := width - utf8.RuneCountInString(s) 646 if padding < 0 { 647 return s 648 } 649 return s + strings.Repeat(" ", padding) 650} 651 652// alignCenter returns the string of 'val' padded so that it is centered in a 653// column of the given width. 654func alignCenter(val interface{}, width int) string { 655 s := fmt.Sprint(val) 656 padding := width - utf8.RuneCountInString(s) 657 if padding < 0 { 658 return s 659 } 660 return strings.Repeat(" ", padding/2) + s + strings.Repeat(" ", (padding+1)/2) 661} 662 663// alignRight returns the string of 'val' padded so that it is aligned right in 664// a column of the given width 665func alignRight(val interface{}, width int) string { 666 s := fmt.Sprint(val) 667 padding := width - utf8.RuneCountInString(s) 668 if padding < 0 { 669 return s 670 } 671 return strings.Repeat(" ", padding) + s 672} 673 674// maxStringLen returns the maximum number of runes found in all the strings in 675// 'l' 676func maxStringLen(l []string) int { 677 max := 0 678 for _, s := range l { 679 if c := utf8.RuneCountInString(s); c > max { 680 max = c 681 } 682 } 683 return max 684} 685 686// formatWidth returns the width in runes for the outputFormat column 'b' 687func formatWidth(b outputFormat) int { 688 const min = 6 689 c := utf8.RuneCountInString(string(b)) 690 if c < min { 691 return min 692 } 693 return c 694} 695 696// percentage returns the percentage of n out of total as a string 697func percentage(n, total int) string { 698 if total == 0 { 699 return "-" 700 } 701 f := float64(n) / float64(total) 702 return fmt.Sprintf("%.1f%c", f*100.0, '%') 703} 704 705// invoke runs the executable 'exe' with the provided arguments. 706func invoke(wd, exe string, args ...string) (ok bool, output string) { 707 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 708 defer cancel() 709 710 cmd := exec.CommandContext(ctx, exe, args...) 711 cmd.Dir = wd 712 out, err := cmd.CombinedOutput() 713 str := string(out) 714 if err != nil { 715 if ctx.Err() == context.DeadlineExceeded { 716 return false, fmt.Sprintf("test timed out after %v", testTimeout) 717 } 718 if str != "" { 719 return false, str 720 } 721 return false, err.Error() 722 } 723 return true, str 724} 725 726var reFlags = regexp.MustCompile(` *\/\/ *flags:(.*)\n`) 727 728// parseFlags looks for a `// flags:` header at the start of the file with the 729// given path, returning each of the space delimited tokens that follow for the 730// line 731func parseFlags(path string) []string { 732 content, err := ioutil.ReadFile(path) 733 if err != nil { 734 return nil 735 } 736 header := strings.SplitN(string(content), "\n", 1)[0] 737 m := reFlags.FindStringSubmatch(header) 738 if len(m) != 2 { 739 return nil 740 } 741 return strings.Split(m[1], " ") 742} 743 744func printDuration(d time.Duration) string { 745 sec := int(d.Seconds()) 746 min := int(sec) / 60 747 hour := min / 60 748 min -= hour * 60 749 sec -= min * 60 750 sb := &strings.Builder{} 751 if hour > 0 { 752 fmt.Fprintf(sb, "%dh", hour) 753 } 754 if min > 0 { 755 fmt.Fprintf(sb, "%dm", min) 756 } 757 if sec > 0 { 758 fmt.Fprintf(sb, "%ds", sec) 759 } 760 return sb.String() 761} 762