• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17// Package internal generates Go source code with functions for TensorFlow operations.
18//
19// The basic outline of the generated API is as follows:
20//
21// - One function for each TensorFlow operation
22// - The arguments to the function are the inputs and required attributes of the operation
23// - The function returns the outputs
24// - A function is also generated for each optional attribute of the operation.
25//
26// There is a possibility that there are name collisions between the functions
27// generated for ops and the functions generated for optional attributes. For
28// now, we ignore those, but will need to revisit if a collision is actually
29// encountered.
30package internal
31
32/*
33#include <stdlib.h>
34
35#include "tensorflow/c/c_api.h"
36*/
37import "C"
38
39import (
40	"fmt"
41	"io"
42	"io/ioutil"
43	"path"
44	"reflect"
45	"strings"
46	"text/template"
47	"unsafe"
48
49	"google.golang.org/protobuf/encoding/prototext"
50	"google.golang.org/protobuf/proto"
51	adpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto"
52	odpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto"
53)
54
55// GenerateFunctionsForRegisteredOps writes a Go source code file to w
56// containing functions for each TensorFlow operation registered in the address
57// space of the calling process.
58// apidefDirs should be a contain of directories containing api_def_*.pbtxt
59// files to load.
60func GenerateFunctionsForRegisteredOps(
61	w io.Writer, apidefDirs []string) error {
62	ops, apimap, err := registeredOps()
63	if err != nil {
64		return err
65	}
66	for _, dir := range apidefDirs {
67		if err = updateAPIDefs(apimap, dir); err != nil {
68			return err
69		}
70	}
71	return generateFunctionsForOps(w, ops, apimap)
72}
73
74func registeredOps() (*odpb.OpList, *apiDefMap, error) {
75	buf := C.TF_GetAllOpList()
76	defer C.TF_DeleteBuffer(buf)
77	var (
78		list = new(odpb.OpList)
79		size = int(buf.length)
80		// A []byte backed by C memory.
81		// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
82		data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
83		err  = proto.Unmarshal(data, list)
84	)
85	if err != nil {
86		return nil, nil, err
87	}
88	apimap, err := newAPIDefMap(list)
89	return list, apimap, err
90}
91
92func updateAPIDefs(m *apiDefMap, dir string) error {
93	files, err := ioutil.ReadDir(dir)
94	if err != nil {
95		return err
96	}
97	for _, file := range files {
98		if file.IsDir() || !strings.HasSuffix(file.Name(), ".pbtxt") {
99			continue
100		}
101		data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
102		if err != nil {
103			return fmt.Errorf("failed to read %q: %v", file.Name(), err)
104		}
105		if err = m.Put(string(data)); err != nil {
106			return fmt.Errorf("failed to process %q: %v", file.Name(), err)
107		}
108	}
109	return nil
110}
111
112func generateFunctionsForOps(w io.Writer, ops *odpb.OpList, apimap *apiDefMap) error {
113	thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
114	if err := tmplHeader.Execute(w, thisPackage); err != nil {
115		return err
116	}
117	denylist := map[string]bool{
118		"Const":           true,
119		"PyFunc":          true,
120		"PyFuncStateless": true,
121	}
122	for _, op := range ops.Op {
123		if denylist[op.Name] {
124			continue
125		}
126		apidef, err := apimap.Get(op.Name)
127		if err != nil {
128			return err
129		}
130		if err := generateFunctionForOp(w, op, apidef); err != nil {
131			return err
132		}
133	}
134	return nil
135}
136
137func generateFunctionForOp(w io.Writer, op *odpb.OpDef, apidef *adpb.ApiDef) error {
138	if strings.HasPrefix(op.Name, "_") { // Internal operation
139		return nil
140	}
141	// Ignore operations where the Go types corresponding to the TensorFlow
142	// type haven't been worked out (such as "func"s).
143	for _, a := range op.Attr {
144		if _, err := goType(a.Type); err != nil {
145			return nil
146		}
147	}
148	// Also, haven't figured out reference types yet, so ignore those too.
149	for _, a := range op.InputArg {
150		if a.IsRef {
151			return nil
152		}
153	}
154	for _, a := range op.OutputArg {
155		if a.IsRef {
156			return nil
157		}
158	}
159	if apidef.Summary == "" {
160		// Undocumented operation, perhaps a sign of not being ready to
161		// export.
162		return nil
163	}
164	tmplArgs, err := newTmplArgs(op, apidef)
165	if err != nil {
166		return err
167	}
168	return tmplOp.Execute(w, tmplArgs)
169}
170
var (
	// Go keywords that cannot be used as identifiers.
	// From https://golang.org/ref/spec#Keywords
	//
	// identifier consults this list to decide whether a name needs a
	// trailing underscore.
	keywords = []string{
		"break", "default", "func", "interface", "select", "case",
		"defer", "go", "map", "struct", "chan", "else", "goto",
		"package", "switch", "const", "fallthrough", "if", "range",
		"type", "continue", "for", "import", "return", "var",
	}

	// tmplHeader is the preamble of the generated file. It is executed once
	// with the import path of this generator package as its data ({{.}}) and
	// emits the "package op" clause, the tf import, the optionalAttr type and
	// the makeOutputList helper that the per-op functions rely on.
	tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT
// This file was machine generated by {{.}}
//
// WARNING: This generation of wrapper function for TensorFlow ops is in an
// experimental state. The generated API can change without notice.

package op

import tf "github.com/tensorflow/tensorflow/tensorflow/go"

// optionalAttr is an intentionally un-exported type to hide
// details of how optional attributes to operations are implemented.
type optionalAttr map[string]interface{}

func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) {
	size, err := op.OutputListSize(output)
	if err != nil {
		return nil, start, err
	}
	list := make([]tf.Output, size)
	for i := 0; i < size; i++ {
		list[i] = op.Output(start + i)
	}
	return list, start + size, nil
}
`))

	// tmplOp renders the code for one operation. It is executed with a
	// *tmplArgs and emits, in order: the <Op>Attr type and one setter per
	// optional attribute (only when the op has optional attributes), the doc
	// comment, and the wrapper function itself.
	//
	// The template text is whitespace-sensitive ({{- and -}} trim markers
	// control the layout of the generated source), so edit it with care.
	tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{
		"MakeComment":       makeComment,
		"GoType":            goType,
		"CamelCase":         camelCase,
		"Identifier":        identifier,
		"IsListArg":         isListArg,
		"IsListAttr":        isListAttr,
		"StripLeadingColon": stripLeadingColon,
	}).Parse(`
{{if .OptionalAttrs -}}
{{/* Type for specifying all optional attributes. */ -}}
// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}.
type {{.Op.Name}}Attr func(optionalAttr)

{{range .OptionalAttrs}}
// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
{{- if .Description}}
//
// value: {{MakeComment .Description}}
{{- end}}
// If not specified, defaults to {{StripLeadingColon .DefaultValue}}
{{- if .HasMinimum}}
//
// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
{{- end}}
func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
	return func(m optionalAttr) {
		m[{{printf "%q" .Name}}] = value
	}
}
{{end}}
{{end}}

{{- /* Create a godoc friendly comment. */ -}}

// {{MakeComment .APIDef.Summary}}

{{- with .Op.Deprecation}}
//
// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
{{- end -}}

{{- with .APIDef.Description}}
//
// {{MakeComment .}}
{{- end -}}

{{- if .DescribeArguments}}
//
// Arguments:
{{- range .InArgsReordered}}
//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- range .RequiredAttrs}}
//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- end -}}

{{- if (not .Op.OutputArg) }}
//
// Returns the created operation.
{{- else }}
{{- if .DescribeOutputs}}
//
{{- if eq (len .OutArgs) 1 }}
// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
{{- else }}
// Returns:
{{- range .OutArgs}}
//	{{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- end -}}
{{- end -}}
{{- end -}}
{{- /*

  The function signature.
  Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang
*/}}
func {{.Op.Name}}

{{- /*
  Fill in input arguments:
  (1) The Scope
  (2) All input arguments (which may be either []tf.Output or tf.Output)
  (3) All required attributes
  (4) Variadic list of optional attributes
*/ -}}

(scope *Scope
{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
)

{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}

{{if .OutArgs -}}
({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
{{- else -}}
(o *tf.Operation)
{{- end }} {
	if scope.Err() != nil {
		return
	}
	{{if .HasAttrs -}}
	attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
	{{if .OptionalAttrs -}}
	for _, a := range optional {
		a(attrs)
	}
	{{end -}}
	{{end -}}
	opspec := tf.OpSpec{
		Type: {{printf "%q" .Op.Name}},
		{{if .InArgs -}}
		Input: []tf.Input{
			{{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
		},
		{{- end}}
		{{- if .HasAttrs}}
		Attrs: attrs,
		{{- end}}
	}
	{{- if .OutArgs}}
	{{- if .HasListOutput}}
	op := scope.AddOperation(opspec)
	if scope.Err() != nil {
		return
	}
	var idx int
	var err error
	{{- range $i, $a := .OutArgs}}
	{{- if $a.IsListArg}}
	if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
		scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
		return
	}
	{{- else }}
	{{Identifier .RenameTo}} = op.Output(idx)
	{{- end }}{{- /* if IsListArg */}}
	{{- end }}{{- /* range .OutArgs */}}
	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
	{{- else }}
	op := scope.AddOperation(opspec)
	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
	{{- end }}{{- /* if .HasListOutput */}}
	{{- else }}
	return scope.AddOperation(opspec)
	{{- end }}{{- /* if .OutArgs */}}
}
`))
)
361
362type attrWrapper struct {
363	op  *odpb.OpDef_AttrDef
364	api *adpb.ApiDef_Attr
365}
366
367func (a *attrWrapper) Name() string              { return a.api.Name }
368func (a *attrWrapper) RenameTo() string          { return a.api.RenameTo }
369func (a *attrWrapper) Description() string       { return a.api.Description }
370func (a *attrWrapper) Type() string              { return a.op.Type }
371func (a *attrWrapper) IsListAttr() bool          { return isListAttr(a.op) }
372func (a *attrWrapper) HasMinimum() bool          { return a.op.HasMinimum }
373func (a *attrWrapper) Minimum() int64            { return a.op.Minimum }
374func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
375
376type argWrapper struct {
377	op  *odpb.OpDef_ArgDef
378	api *adpb.ApiDef_Arg
379}
380
381func (a *argWrapper) Name() string        { return a.api.Name }
382func (a *argWrapper) RenameTo() string    { return a.api.RenameTo }
383func (a *argWrapper) Description() string { return a.api.Description }
384func (a *argWrapper) IsListArg() bool     { return isListArg(a.op) }
385
386type tmplArgs struct {
387	Op     *odpb.OpDef
388	APIDef *adpb.ApiDef
389	// Op.Attr is split into two categories
390	// (1) Required: These must be specified by the client and are thus
391	//     included in the function signature.
392	// (2) Optional: These need not be specified (as they have default
393	//     values) and thus do not appear in the function signature.
394	RequiredAttrs []*attrWrapper
395	OptionalAttrs []*attrWrapper
396	InArgs        []*argWrapper
397	// Input arguments ordered based on arg_order field of ApiDef.
398	InArgsReordered []*argWrapper
399	OutArgs         []*argWrapper
400}
401
402func newTmplArgs(op *odpb.OpDef, apidef *adpb.ApiDef) (*tmplArgs, error) {
403	ret := tmplArgs{Op: op, APIDef: apidef}
404
405	// Setup InArgs field
406	for i, in := range op.InputArg {
407		argCombined := argWrapper{op: in, api: apidef.InArg[i]}
408		ret.InArgs = append(ret.InArgs, &argCombined)
409	}
410
411	// Setup OutArgs field
412	for i, out := range op.OutputArg {
413		argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
414		ret.OutArgs = append(ret.OutArgs, &argCombined)
415	}
416
417	// Setup InArgsReordered field
418	for _, argName := range apidef.ArgOrder {
419		// Find the argument in op.InputArg
420		argIndex := -1
421		for i, in := range op.InputArg {
422			if in.Name == argName {
423				argIndex = i
424				break
425			}
426		}
427		if argIndex == -1 {
428			return nil, fmt.Errorf(
429				"couldn't find argument %s in ApiDef for op %s",
430				argName, op.Name)
431		}
432		argCombined := argWrapper{
433			op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
434		ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
435	}
436
437	if len(op.Attr) == 0 {
438		return &ret, nil
439	}
440	// Attributes related to the InputArg's type are inferred automatically
441	// and are not exposed to the client.
442	inferred := make(map[string]bool)
443	for _, in := range op.InputArg {
444		switch {
445		case in.TypeAttr != "":
446			inferred[in.TypeAttr] = true
447		case in.TypeListAttr != "":
448			inferred[in.TypeListAttr] = true
449		}
450		if in.NumberAttr != "" {
451			inferred[in.NumberAttr] = true
452		}
453	}
454	for i, attr := range op.Attr {
455		if inferred[attr.Name] {
456			continue
457		}
458		attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
459		if attr.DefaultValue == nil {
460			ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
461		} else {
462			ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
463		}
464	}
465	return &ret, nil
466}
467
468func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
469func (a *tmplArgs) DescribeArguments() bool {
470	for _, arg := range a.InArgs {
471		if arg.Description() != "" {
472			return true
473		}
474	}
475	for _, attr := range a.RequiredAttrs {
476		if attr.Description() != "" {
477			return true
478		}
479	}
480	return false
481
482}
483func (a *tmplArgs) DescribeOutputs() bool {
484	for _, arg := range a.OutArgs {
485		if arg.Description() != "" {
486			return true
487		}
488	}
489	return false
490}
491func (a *tmplArgs) HasListOutput() bool {
492	for _, arg := range a.OutArgs {
493		if arg.IsListArg() {
494			return true
495		}
496	}
497	return false
498}
499
// makeComment turns multi-line text into the continuation of a "//" comment
// by inserting "// " after every newline. The caller is expected to have
// written the "// " that precedes the first line.
func makeComment(lines string) string {
	return strings.ReplaceAll(lines, "\n", "\n// ")
}
503
504// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.)
505// to the corresponding type in Go.
506func goType(tfType string) (string, error) {
507	list, tfType := parseTFType(tfType)
508	var gotype string
509	switch tfType {
510	case "int":
511		gotype = "int64"
512	case "float":
513		gotype = "float32"
514	case "bool":
515		gotype = "bool"
516	case "type":
517		gotype = "tf.DataType"
518	case "shape":
519		gotype = "tf.Shape"
520	case "tensor":
521		gotype = "tf.Tensor"
522	case "string":
523		gotype = "string"
524	default:
525		return "", fmt.Errorf("%q is not a recognized DataType", tfType)
526	}
527	if list {
528		gotype = "[]" + gotype
529	}
530	return gotype, nil
531}
532
// camelCase converts a snake_case name to CamelCase by upper-casing the
// first byte of every "_"-separated word and joining the words.
//
// Empty words, which arise from an empty input or from leading, trailing or
// doubled underscores, are dropped rather than indexed, so such inputs no
// longer cause an index-out-of-range panic.
func camelCase(snakeCase string) string {
	words := strings.Split(snakeCase, "_")
	for i, w := range words {
		if w == "" {
			continue
		}
		words[i] = strings.ToUpper(w[:1]) + w[1:]
	}
	return strings.Join(words, "")
}
540
541// identifier creates an identifier for s usable in the generated Go source
542// code.
543//
544// Avoids collisions with keywords and other identifiers used in the generated
545// code.
546func identifier(s string) string {
547	// Identifiers used in the generated code.
548	if s == "tf" || s == "scope" || s == "err" || s == "op" {
549		return s + "_"
550	}
551	for _, k := range keywords {
552		if s == k {
553			// Alternatively, make the first letter upper case.
554			return s + "_"
555		}
556	}
557	return s
558}
559
560func isListArg(argdef *odpb.OpDef_ArgDef) bool {
561	return argdef.TypeListAttr != "" || argdef.NumberAttr != ""
562}
563
564func isListAttr(attrdef *odpb.OpDef_AttrDef) bool {
565	list, _ := parseTFType(attrdef.Type)
566	return list
567}
568
569// stripLeadingColon removes the prefix of the string up to the first colon.
570//
571// This is useful when 's' corresponds to a "oneof" protocol buffer message.
572// For example, consider the protocol buffer message:
573//   oneof value { bool b = 1;  int64 i = 2; }
574// proto.CompactTextString) will print "b:true", or "i:7" etc. This function
575// strips out the leading "b:" or "i:".
576func stripLeadingColon(m proto.Message) string {
577	o := prototext.MarshalOptions{Multiline: false}
578	x := o.Format(m)
579	y := strings.SplitN(x, ":", 2)
580	if len(y) < 2 {
581		return x
582	}
583	return y[1]
584}
585
// parseTFType splits a TensorFlow attribute type string into a list flag and
// an element type.
//
// For "list(T)" it returns (true, "T"); any other string is returned
// unchanged with list set to false.
func parseTFType(tfType string) (list bool, typ string) {
	const open, shut = "list(", ")"
	if !strings.HasPrefix(tfType, open) || !strings.HasSuffix(tfType, shut) {
		return false, tfType
	}
	return true, tfType[len(open) : len(tfType)-len(shut)]
}
596