• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package gotestmain
16
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"strings"
	"text/template"
	"unicode"
	"unicode/utf8"
)
29
// output is the -o flag: the file the generated test main is written to.
var output = flag.String("o", "", "output filename")

// pkg is the -pkg flag: the import path of the package under test.
var pkg = flag.String("pkg", "", "test package")

// exitCode is the process status main reports; main sets it to a non-zero
// value when it reports an error.
var exitCode int
35
// data holds the values substituted into testMainTmpl.
type data struct {
	// Package is the import path of the package under test.
	Package string

	// Tests lists the names of the test functions to register.
	Tests []string
}
40
// findTests parses each Go source file in srcs and returns the names of its
// top-level test functions, ordered by file and then by declaration order.
//
// A function is reported when it has no receiver and its name satisfies
// isTestName, so helpers such as "Testify" and methods named Test* are not
// mistaken for tests. It panics if any file cannot be parsed.
func findTests(srcs []string) (tests []string) {
	fset := token.NewFileSet()
	for _, src := range srcs {
		f, err := parser.ParseFile(fset, src, nil, 0)
		if err != nil {
			panic(err)
		}
		// Walk f.Decls instead of f.Scope.Objects: the scope is a map, so
		// ranging over it would emit tests in a different order each run and
		// make the generated file non-deterministic.
		for _, decl := range f.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || fn.Recv != nil || !isTestName(fn.Name.Name) {
				continue
			}
			tests = append(tests, fn.Name.Name)
		}
	}
	return tests
}

// isTestName reports whether name follows the go tool's naming rule for test
// functions: exactly "Test", or "Test" followed by a rune that is not lower
// case (so "TestFoo", "Test_foo" and "Test1" qualify but "Testify" does not).
func isTestName(name string) bool {
	const prefix = "Test"
	if !strings.HasPrefix(name, prefix) {
		return false
	}
	if len(name) == len(prefix) {
		return true
	}
	r, _ := utf8.DecodeRuneInString(name[len(prefix):])
	return !unicode.IsLower(r)
}
56
57func main() {
58	flag.Parse()
59
60	if flag.NArg() == 0 {
61		fmt.Fprintln(os.Stderr, "error: must pass at least one input")
62		exitCode = 1
63		return
64	}
65
66	buf := &bytes.Buffer{}
67
68	d := data{
69		Package: *pkg,
70		Tests:   findTests(flag.Args()),
71	}
72
73	err := testMainTmpl.Execute(buf, d)
74	if err != nil {
75		panic(err)
76	}
77
78	err = ioutil.WriteFile(*output, buf.Bytes(), 0666)
79	if err != nil {
80		panic(err)
81	}
82}
83
84var testMainTmpl = template.Must(template.New("testMain").Parse(`
85package main
86
87import (
88	"testing"
89
90	pkg "{{.Package}}"
91)
92
93var t = []testing.InternalTest{
94{{range .Tests}}
95	{"{{.}}", pkg.{{.}}},
96{{end}}
97}
98
99func matchString(pat, str string) (bool, error) {
100	return true, nil
101}
102
103func main() {
104	testing.Main(matchString, t, nil, nil)
105}
106`))
107