• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2017 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package generate
16
17import (
18	"flag"
19	"fmt"
20	"io/ioutil"
21	"os"
22	"path/filepath"
23	"strings"
24
25	"github.com/google/wuffs/lang/check"
26	"github.com/google/wuffs/lang/parse"
27	"github.com/google/wuffs/lang/wuffsroot"
28
29	a "github.com/google/wuffs/lang/ast"
30	t "github.com/google/wuffs/lang/token"
31)
32
// Generator produces the generated output for one Wuffs package.
//
// Do calls it with the package name (already validated and lower-cased by
// checkPackageName), the token map used while tokenizing, the checker
// returned by check.Check, and the parsed files. Do writes the returned bytes
// to standard output.
//
// For the "base" package with no input files, Do instead calls it as
// g("base", nil, nil, nil), so an implementation must tolerate nil tm, c and
// files in that case.
type Generator func(packageName string, tm *t.Map, c *check.Checker, files []*a.File) ([]byte, error)
34
35func Do(flags *flag.FlagSet, args []string, g Generator) error {
36	packageName := flags.String("package_name", "", "the package name of the Wuffs input code")
37	if err := flags.Parse(args); err != nil {
38		return err
39	}
40	out := []byte(nil)
41
42	if *packageName == "base" && len(flags.Args()) == 0 {
43		var err error
44		out, err = g("base", nil, nil, nil)
45		if err != nil {
46			return err
47		}
48
49	} else {
50		pkgName := checkPackageName(*packageName)
51		if pkgName == "" {
52			return fmt.Errorf("prohibited package name %q", *packageName)
53		}
54
55		tm := &t.Map{}
56		files, err := parseFiles(tm, flags.Args())
57		if err != nil {
58			return err
59		}
60
61		c, err := check.Check(tm, files, resolveUse)
62		if err != nil {
63			return err
64		}
65
66		out, err = g(pkgName, tm, c, files)
67		if err != nil {
68			return err
69		}
70	}
71
72	_, err := os.Stdout.Write(out)
73	return err
74}
75
// checkPackageName validates and normalizes a Wuffs package name.
//
// It returns the lower-cased form of s if s is acceptable, and "" otherwise.
// An acceptable name:
//   - consists only of ASCII letters, ASCII digits and underscores,
//   - contains at least one byte that is not an underscore (so "" and "___"
//     are rejected), and
//   - once lower-cased, is not one of the reserved names listed below.
//
// NOTE(review): the reserved names presumably clash with identifiers the
// code generators emit; confirm before extending the list.
func checkPackageName(s string) string {
	sawNonUnderscore := false
	for _, c := range []byte(s) {
		switch {
		case c == '_':
			// Permitted, but underscores alone never make a valid name.
		case ('0' <= c && c <= '9') || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'):
			sawNonUnderscore = true
		default:
			// Any other byte, including every byte of a multi-byte UTF-8
			// sequence, disqualifies the name.
			return ""
		}
	}
	if !sawNonUnderscore {
		return ""
	}

	lower := strings.ToLower(s)
	// Reject reserved names. The comparison is made after lower-casing, so
	// "Base" and "BASE" are rejected as well.
	switch lower {
	case "base", "base_private", "base_private_header", "base_public", "base_public_header",
		"config", "implementation", "include_guard", "initialize", "version":
		return ""
	}
	return lower
}
98
99func parseFiles(tm *t.Map, filenames []string) (files []*a.File, err error) {
100	if len(filenames) == 0 {
101		const filename = "stdin"
102		src, err := ioutil.ReadAll(os.Stdin)
103		if err != nil {
104			return nil, err
105		}
106		tokens, _, err := t.Tokenize(tm, filename, src)
107		if err != nil {
108			return nil, err
109		}
110		f, err := parse.Parse(tm, filename, tokens, nil)
111		if err != nil {
112			return nil, err
113		}
114		return []*a.File{f}, nil
115	}
116	return ParseFiles(tm, filenames, nil)
117}
118
119func ParseFiles(tm *t.Map, filenames []string, opts *parse.Options) (files []*a.File, err error) {
120	for _, filename := range filenames {
121		src, err := ioutil.ReadFile(filename)
122		if err != nil {
123			return nil, err
124		}
125		tokens, _, err := t.Tokenize(tm, filename, src)
126		if err != nil {
127			return nil, err
128		}
129		f, err := parse.Parse(tm, filename, tokens, opts)
130		if err != nil {
131			return nil, err
132		}
133		files = append(files, f)
134	}
135	return files, nil
136}
137
138func resolveUse(usePath string) ([]byte, error) {
139	wuffsRoot, err := wuffsroot.Value()
140	if err != nil {
141		return nil, err
142	}
143	return ioutil.ReadFile(filepath.Join(wuffsRoot, "gen", "wuffs", filepath.FromSlash(usePath)))
144}
145