// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package main
16
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/doc"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"reflect"
	"sort"
	"strings"
	"testing"
	"text/template"
	"unicode"
	"unicode/utf8"
)
33
// Command-line flags.
var (
	// output names the file the generated test main is written to (-o).
	output = flag.String("o", "", "output filename")

	// pkg is the import path of the package under test (-pkg); the generated
	// file imports it under the name pkg.
	pkg = flag.String("pkg", "", "test package")
)

// exitCode is the status the program intends to exit with; main sets it to 1
// when it rejects its arguments.
var exitCode int
39
// data holds the values testMainTmpl is executed with.
type data struct {
	// Package is the import path of the package under test.
	Package string

	// Tests lists the names of the package's test functions, excluding TestMain.
	Tests []string

	// Examples lists the package's examples; the template only registers the
	// ones that declare expected output.
	Examples []*doc.Example

	// HasMain records whether the package defines TestMain, and
	// MainStartTakesInterface selects which testing.MainStart calling
	// convention the template emits.
	HasMain, MainStartTakesInterface bool
}
47
// findTests parses each Go source file in srcs and collects the package-level
// test functions, the examples and whether a TestMain function is declared.
//
// A function counts as a test when its name satisfies isTestName, so helpers
// such as Testable are not mistaken for tests. TestMain is reported through
// hasMain instead of being listed in tests. The returned test names are sorted
// so the generated file is deterministic. findTests panics if a file cannot
// be parsed.
func findTests(srcs []string) (tests []string, examples []*doc.Example, hasMain bool) {
	for _, src := range srcs {
		// Comments are parsed so doc.Examples can read each example's output comment.
		f, err := parser.ParseFile(token.NewFileSet(), src, nil, parser.ParseComments)
		if err != nil {
			panic(err)
		}
		for _, decl := range f.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			// Methods are never tests; only package-level functions qualify.
			if !ok || fn.Recv != nil {
				continue
			}
			name := fn.Name.Name
			if !isTestName(name) {
				continue
			}
			if name == "TestMain" {
				hasMain = true
			} else {
				tests = append(tests, name)
			}
		}

		examples = append(examples, doc.Examples(f)...)
	}
	sort.Strings(tests)
	return tests, examples, hasMain
}

// isTestName reports whether name follows go test's naming rule for test
// functions: "Test" alone, or "Test" followed by a suffix whose first
// character is not a lowercase letter (TestFoo and Test_foo, but not Testable).
func isTestName(name string) bool {
	const prefix = "Test"
	if !strings.HasPrefix(name, prefix) {
		return false
	}
	rest := name[len(prefix):]
	if rest == "" {
		return true
	}
	r, _ := utf8.DecodeRuneInString(rest)
	return !unicode.IsLower(r)
}
70
// mainStartTakesInterface reports whether the first parameter of
// testing.MainStart is an interface (go1.8+) rather than a match function.
func mainStartTakesInterface() bool {
	firstParam := reflect.TypeOf(testing.MainStart).In(0)
	return firstParam.Kind() == reflect.Interface
}
76
77func main() {
78	flag.Parse()
79
80	if flag.NArg() == 0 {
81		fmt.Fprintln(os.Stderr, "error: must pass at least one input")
82		exitCode = 1
83		return
84	}
85
86	buf := &bytes.Buffer{}
87
88	tests, examples, hasMain := findTests(flag.Args())
89
90	d := data{
91		Package:                 *pkg,
92		Tests:                   tests,
93		Examples:                examples,
94		HasMain:                 hasMain,
95		MainStartTakesInterface: mainStartTakesInterface(),
96	}
97
98	err := testMainTmpl.Execute(buf, d)
99	if err != nil {
100		panic(err)
101	}
102
103	err = ioutil.WriteFile(*output, buf.Bytes(), 0666)
104	if err != nil {
105		panic(err)
106	}
107}
108
109var testMainTmpl = template.Must(template.New("testMain").Parse(`
110package main
111
112import (
113	"io"
114{{if not .HasMain}}
115	"os"
116{{end}}
117	"regexp"
118	"testing"
119
120	pkg "{{.Package}}"
121)
122
123var t = []testing.InternalTest{
124{{range .Tests}}
125	{"{{.}}", pkg.{{.}}},
126{{end}}
127}
128
129var e = []testing.InternalExample{
130{{range .Examples}}
131	{{if or .Output .EmptyOutput}}
132		{"{{.Name}}", pkg.Example{{.Name}}, {{.Output | printf "%q" }}, {{.Unordered}}},
133	{{end}}
134{{end}}
135}
136
137var matchPat string
138var matchRe *regexp.Regexp
139
140type matchString struct{}
141
142func MatchString(pat, str string) (result bool, err error) {
143	if matchRe == nil || matchPat != pat {
144		matchPat = pat
145		matchRe, err = regexp.Compile(matchPat)
146		if err != nil {
147			return
148		}
149	}
150	return matchRe.MatchString(str), nil
151}
152
153func (matchString) MatchString(pat, str string) (bool, error) {
154	return MatchString(pat, str)
155}
156
157func (matchString) StartCPUProfile(w io.Writer) error {
158	panic("shouldn't get here")
159}
160
161func (matchString) StopCPUProfile() {
162}
163
164func (matchString) WriteHeapProfile(w io.Writer) error {
165    panic("shouldn't get here")
166}
167
168func (matchString) WriteProfileTo(string, io.Writer, int) error {
169    panic("shouldn't get here")
170}
171
172func (matchString) ImportPath() string {
173	return "{{.Package}}"
174}
175
176func (matchString) StartTestLog(io.Writer) {
177	panic("shouldn't get here")
178}
179
180func (matchString) StopTestLog() error {
181	panic("shouldn't get here")
182}
183
184func main() {
185{{if .MainStartTakesInterface}}
186	m := testing.MainStart(matchString{}, t, nil, e)
187{{else}}
188	m := testing.MainStart(MatchString, t, nil, e)
189{{end}}
190{{if .HasMain}}
191	pkg.TestMain(m)
192{{else}}
193	os.Exit(m.Run())
194{{end}}
195}
196`))
197