1// Copyright 2017 The Wuffs Authors. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package generate 16 17import ( 18 "flag" 19 "fmt" 20 "io/ioutil" 21 "os" 22 "path/filepath" 23 "strings" 24 25 "github.com/google/wuffs/lang/check" 26 "github.com/google/wuffs/lang/parse" 27 "github.com/google/wuffs/lang/wuffsroot" 28 29 a "github.com/google/wuffs/lang/ast" 30 t "github.com/google/wuffs/lang/token" 31) 32 33type Generator func(packageName string, tm *t.Map, c *check.Checker, files []*a.File) ([]byte, error) 34 35func Do(flags *flag.FlagSet, args []string, g Generator) error { 36 packageName := flags.String("package_name", "", "the package name of the Wuffs input code") 37 if err := flags.Parse(args); err != nil { 38 return err 39 } 40 out := []byte(nil) 41 42 if *packageName == "base" && len(flags.Args()) == 0 { 43 var err error 44 out, err = g("base", nil, nil, nil) 45 if err != nil { 46 return err 47 } 48 49 } else { 50 pkgName := checkPackageName(*packageName) 51 if pkgName == "" { 52 return fmt.Errorf("prohibited package name %q", *packageName) 53 } 54 55 tm := &t.Map{} 56 files, err := parseFiles(tm, flags.Args()) 57 if err != nil { 58 return err 59 } 60 61 c, err := check.Check(tm, files, resolveUse) 62 if err != nil { 63 return err 64 } 65 66 out, err = g(pkgName, tm, c, files) 67 if err != nil { 68 return err 69 } 70 } 71 72 _, err := os.Stdout.Write(out) 73 return err 74} 75 76func checkPackageName(s string) string { 77 allUnderscores := true 78 for i := 0; i < len(s); i++ { 79 c := s[i] 80 if ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || (c == '_') || ('0' <= c && c <= '9') { 81 allUnderscores = allUnderscores && c == '_' 82 } else { 83 return "" 84 } 85 } 86 if allUnderscores { 87 return "" 88 } 89 s = strings.ToLower(s) 90 // Blacklist certain package names. 91 switch s { 92 case "base", "base_private", "base_private_header", "base_public", "base_public_header", 93 "config", "implementation", "include_guard", "initialize", "version": 94 return "" 95 } 96 return s 97} 98 99func parseFiles(tm *t.Map, filenames []string) (files []*a.File, err error) { 100 if len(filenames) == 0 { 101 const filename = "stdin" 102 src, err := ioutil.ReadAll(os.Stdin) 103 if err != nil { 104 return nil, err 105 } 106 tokens, _, err := t.Tokenize(tm, filename, src) 107 if err != nil { 108 return nil, err 109 } 110 f, err := parse.Parse(tm, filename, tokens, nil) 111 if err != nil { 112 return nil, err 113 } 114 return []*a.File{f}, nil 115 } 116 return ParseFiles(tm, filenames, nil) 117} 118 119func ParseFiles(tm *t.Map, filenames []string, opts *parse.Options) (files []*a.File, err error) { 120 for _, filename := range filenames { 121 src, err := ioutil.ReadFile(filename) 122 if err != nil { 123 return nil, err 124 } 125 tokens, _, err := t.Tokenize(tm, filename, src) 126 if err != nil { 127 return nil, err 128 } 129 f, err := parse.Parse(tm, filename, tokens, opts) 130 if err != nil { 131 return nil, err 132 } 133 files = append(files, f) 134 } 135 return files, nil 136} 137 138func resolveUse(usePath string) ([]byte, error) { 139 wuffsRoot, err := wuffsroot.Value() 140 if err != nil { 141 return nil, err 142 } 143 return ioutil.ReadFile(filepath.Join(wuffsRoot, "gen", "wuffs", filepath.FromSlash(usePath))) 144} 145