// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Command-line tool that generates the source of a test "main" package.
//
// It parses the Go files given as positional arguments, collects their
// top-level Test* functions and examples, renders testMainTmpl for the
// package named by -pkg, and writes the result to the file named by -o.
package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/doc"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"reflect"
	"sort"
	"strings"
	"testing"
	"text/template"
)

var (
	output   = flag.String("o", "", "output filename")
	pkg      = flag.String("pkg", "", "test package")
	exitCode = 0
)

// data is the value passed to testMainTmpl when rendering the generated file.
type data struct {
	// Package is the import path of the package under test.
	Package string
	// Tests holds the names of the discovered test functions, sorted.
	Tests []string
	// Examples holds the examples extracted from the input files.
	Examples []*doc.Example
	// HasMain is true when one of the inputs declares TestMain.
	HasMain bool
	// MainStartTakesInterface selects which testing.MainStart call form
	// the template emits.
	MainStartTakesInterface bool
}

// findTests parses each file in srcs and returns the names of all top-level
// functions whose name starts with "Test" (sorted, excluding TestMain), the
// examples found by doc.Examples, and whether any file declares TestMain.
//
// It panics if a file cannot be parsed.
func findTests(srcs []string) (tests []string, examples []*doc.Example, hasMain bool) {
	for _, src := range srcs {
		f, err := parser.ParseFile(token.NewFileSet(), src, nil, parser.ParseComments)
		if err != nil {
			panic(err)
		}

		// Walk the declarations directly instead of the deprecated
		// ast.Scope; methods (non-nil receiver) are not package-level
		// test functions and are skipped.
		for _, decl := range f.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || fn.Recv != nil || !strings.HasPrefix(fn.Name.Name, "Test") {
				continue
			}
			if fn.Name.Name == "TestMain" {
				hasMain = true
				continue
			}
			tests = append(tests, fn.Name.Name)
		}

		examples = append(examples, doc.Examples(f)...)
	}
	sort.Strings(tests)
	return tests, examples, hasMain
}

// mainStartTakesInterface reports whether the first parameter of
// testing.MainStart is an interface (go1.8+) rather than a function.
func mainStartTakesInterface() bool {
	return reflect.TypeOf(testing.MainStart).In(0).Kind() == reflect.Interface
}

// main parses the flags, scans the input files for tests and examples,
// renders testMainTmpl and writes the result to the -o file.
//
// It exits with a non-zero status when no input files are given and panics
// if rendering the template or writing the output fails.
func main() {
	flag.Parse()

	if flag.NArg() == 0 {
		fmt.Fprintln(os.Stderr, "error: must pass at least one input")
		exitCode = 1
		// A bare return would terminate with status 0 and hide the
		// usage error from the caller, so exit explicitly.
		os.Exit(exitCode)
	}

	buf := &bytes.Buffer{}

	tests, examples, hasMain := findTests(flag.Args())

	d := data{
		Package:                 *pkg,
		Tests:                   tests,
		Examples:                examples,
		HasMain:                 hasMain,
		MainStartTakesInterface: mainStartTakesInterface(),
	}

	err := testMainTmpl.Execute(buf, d)
	if err != nil {
		panic(err)
	}

	err = ioutil.WriteFile(*output, buf.Bytes(), 0666)
	if err != nil {
		panic(err)
	}
}

// testMainTmpl is the template for the generated test main package. It is
// executed with a data value.
var testMainTmpl = template.Must(template.New("testMain").Parse(`
package main

import (
	"io"
{{if not .HasMain}}
	"os"
{{end}}
	"regexp"
	"testing"

	pkg "{{.Package}}"
)

var t = []testing.InternalTest{
{{range .Tests}}
	{"{{.}}", pkg.{{.}}},
{{end}}
}

var e = []testing.InternalExample{
{{range .Examples}}
	{{if or .Output .EmptyOutput}}
		{"{{.Name}}", pkg.Example{{.Name}}, {{.Output | printf "%q" }}, {{.Unordered}}},
	{{end}}
{{end}}
}

var matchPat string
var matchRe *regexp.Regexp

type matchString struct{}

func MatchString(pat, str string) (result bool, err error) {
	if matchRe == nil || matchPat != pat {
		matchPat = pat
		matchRe, err = regexp.Compile(matchPat)
		if err != nil {
			return
		}
	}
	return matchRe.MatchString(str), nil
}

func (matchString) MatchString(pat, str string) (bool, error) {
	return MatchString(pat, str)
}

func (matchString) StartCPUProfile(w io.Writer) error {
	panic("shouldn't get here")
}

func (matchString) StopCPUProfile() {
}

func (matchString) WriteHeapProfile(w io.Writer) error {
	panic("shouldn't get here")
}

func (matchString) WriteProfileTo(string, io.Writer, int) error {
	panic("shouldn't get here")
}

func (matchString) ImportPath() string {
	return "{{.Package}}"
}

func (matchString) StartTestLog(io.Writer) {
	panic("shouldn't get here")
}

func (matchString) StopTestLog() error {
	panic("shouldn't get here")
}

func main() {
{{if .MainStartTakesInterface}}
	m := testing.MainStart(matchString{}, t, nil, e)
{{else}}
	m := testing.MainStart(MatchString, t, nil, e)
{{end}}
{{if .HasMain}}
	pkg.TestMain(m)
{{else}}
	os.Exit(m.Run())
{{end}}
}
`))