• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17// Package internal generates Go source code with functions for TensorFlow operations.
18//
19// The basic outline of the generated API is as follows:
20//
21// - One function for each TensorFlow operation
22// - The arguments to the function are the inputs and required attributes of the operation
23// - The function returns the outputs
24// - A function is also generated for each optional attribute of the operation.
25//
26// There is a possibility that there are name collisions between the functions
27// generated for ops and the functions generated for optional attributes. For
28// now, we ignore those, but will need to revisit if a collision is actually
29// encountered.
30package internal
31
32/*
33#include <stdlib.h>
34
35#include "tensorflow/c/c_api.h"
36*/
37import "C"
38
39import (
40	"fmt"
41	"io"
42	"io/ioutil"
43	"path"
44	"reflect"
45	"sort"
46	"strings"
47	"text/template"
48	"unsafe"
49
50	"google.golang.org/protobuf/encoding/prototext"
51	"google.golang.org/protobuf/proto"
52
53	adpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto"
54	odpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto"
55)
56
57// GenerateFunctionsForRegisteredOps writes a Go source code file to w
58// containing functions for each TensorFlow operation registered in the address
59// space of the calling process.
60// apidefDirs should be a contain of directories containing api_def_*.pbtxt
61// files to load.
62func GenerateFunctionsForRegisteredOps(
63	w io.Writer, apidefDirs []string) error {
64	ops, apimap, err := registeredOps()
65	if err != nil {
66		return err
67	}
68	for _, dir := range apidefDirs {
69		if err = updateAPIDefs(apimap, dir); err != nil {
70			return err
71		}
72	}
73	return generateFunctionsForOps(w, ops, apimap)
74}
75
76func registeredOps() (*odpb.OpList, *apiDefMap, error) {
77	buf := C.TF_GetAllOpList()
78	defer C.TF_DeleteBuffer(buf)
79	var (
80		list = new(odpb.OpList)
81		size = int(buf.length)
82		// A []byte backed by C memory.
83		// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
84		data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
85		err  = proto.Unmarshal(data, list)
86	)
87	if err != nil {
88		return nil, nil, err
89	}
90	// Sort ops by name
91	sort.Slice(list.Op, func(i, j int) bool {
92		return list.Op[i].Name < list.Op[j].Name
93	})
94	apimap, err := newAPIDefMap(list)
95	return list, apimap, err
96}
97
98func updateAPIDefs(m *apiDefMap, dir string) error {
99	files, err := ioutil.ReadDir(dir)
100	if err != nil {
101		return err
102	}
103	for _, file := range files {
104		if file.IsDir() || !strings.HasSuffix(file.Name(), ".pbtxt") {
105			continue
106		}
107		data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
108		if err != nil {
109			return fmt.Errorf("failed to read %q: %v", file.Name(), err)
110		}
111		if err = m.Put(string(data)); err != nil {
112			return fmt.Errorf("failed to process %q: %v", file.Name(), err)
113		}
114	}
115	return nil
116}
117
118func generateFunctionsForOps(w io.Writer, ops *odpb.OpList, apimap *apiDefMap) error {
119	thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
120	if err := tmplHeader.Execute(w, thisPackage); err != nil {
121		return err
122	}
123	denylist := map[string]bool{
124		"Const":           true,
125		"PyFunc":          true,
126		"PyFuncStateless": true,
127	}
128	for _, op := range ops.Op {
129		if denylist[op.Name] {
130			continue
131		}
132		apidef, err := apimap.Get(op.Name)
133		if err != nil {
134			return err
135		}
136		if err := generateFunctionForOp(w, op, apidef); err != nil {
137			return err
138		}
139	}
140	return nil
141}
142
143func generateFunctionForOp(w io.Writer, op *odpb.OpDef, apidef *adpb.ApiDef) error {
144	if strings.HasPrefix(op.Name, "_") { // Internal operation
145		return nil
146	}
147	// Ignore operations where the Go types corresponding to the TensorFlow
148	// type haven't been worked out (such as "func"s).
149	for _, a := range op.Attr {
150		if _, err := goType(a.Type); err != nil {
151			return nil
152		}
153	}
154	// Also, haven't figured out reference types yet, so ignore those too.
155	for _, a := range op.InputArg {
156		if a.IsRef {
157			return nil
158		}
159	}
160	for _, a := range op.OutputArg {
161		if a.IsRef {
162			return nil
163		}
164	}
165	if apidef.Summary == "" {
166		// Undocumented operation, perhaps a sign of not being ready to
167		// export.
168		return nil
169	}
170	tmplArgs, err := newTmplArgs(op, apidef)
171	if err != nil {
172		return err
173	}
174	return tmplOp.Execute(w, tmplArgs)
175}
176
177var (
178	// Go keywords that cannot be used as identifiers.
179	// From https://golang.org/ref/spec#Keywords
180	keywords = []string{
181		"break", "default", "func", "interface", "select", "case",
182		"defer", "go", "map", "struct", "chan", "else", "goto",
183		"package", "switch", "const", "fallthrough", "if", "range",
184		"type", "continue", "for", "import", "return", "var",
185	}
186
187	tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT
188// This file was machine generated by {{.}}
189//
190// WARNING: This generation of wrapper function for TensorFlow ops is in an
191// experimental state. The generated API can change without notice.
192
193package op
194
195import tf "github.com/tensorflow/tensorflow/tensorflow/go"
196
197// optionalAttr is an intentionally un-exported type to hide
198// details of how optional attributes to operations are implemented.
199type optionalAttr map[string]interface{}
200
201func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) {
202	size, err := op.OutputListSize(output)
203	if err != nil {
204		return nil, start, err
205	}
206	list := make([]tf.Output, size)
207	for i := 0; i < size; i++ {
208		list[i] = op.Output(start + i)
209	}
210	return list, start + size, nil
211}
212`))
213
214	tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{
215		"MakeComment":         makeComment,
216		"GoType":              goType,
217		"CamelCase":           camelCase,
218		"Identifier":          identifier,
219		"IsListArg":           isListArg,
220		"IsListAttr":          isListAttr,
221		"MarshalProtoMessage": marshalProtoMessage,
222	}).Parse(`
223{{if .OptionalAttrs -}}
224{{/* Type for specifying all optional attributes. */ -}}
225// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}.
226type {{.Op.Name}}Attr func(optionalAttr)
227
228{{range .OptionalAttrs}}
229// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
230{{- if .Description}}
231//
232// value: {{MakeComment .Description}}
233{{- end}}
234// If not specified, defaults to {{MarshalProtoMessage .DefaultValue}}
235{{- if .HasMinimum}}
236//
237// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
238{{- end}}
239func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
240	return func(m optionalAttr) {
241		m[{{printf "%q" .Name}}] = value
242	}
243}
244{{end}}
245{{end}}
246
247{{- /* Create a godoc friendly comment. */ -}}
248
249// {{MakeComment .APIDef.Summary}}
250
251{{- with .Op.Deprecation}}
252//
253// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
254{{- end -}}
255
256{{- with .APIDef.Description}}
257//
258// {{MakeComment .}}
259{{- end -}}
260
261{{- if .DescribeArguments}}
262//
263// Arguments:
264{{- range .InArgsReordered}}
265//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
266{{- end -}}
267{{- range .RequiredAttrs}}
268//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
269{{- end -}}
270{{- end -}}
271
272{{- if (not .Op.OutputArg) }}
273//
274// Returns the created operation.
275{{- else }}
276{{- if .DescribeOutputs}}
277//
278{{- if eq (len .OutArgs) 1 }}
279// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
280{{- else }}
281// Returns:
282{{- range .OutArgs}}
283//	{{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
284{{- end -}}
285{{- end -}}
286{{- end -}}
287{{- end -}}
288{{- /*
289
290  The function signature.
291  Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang
292*/}}
293func {{.Op.Name}}
294
295{{- /*
296  Fill in input arguments:
297  (1) The Scope
298  (2) All input arguments (which may be either []tf.Output or tf.Output)
299  (3) All required attributes
300  (4) Variadic list of optional attributes
301*/ -}}
302
303(scope *Scope
304{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
305{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
306{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
307)
308
309{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}
310
311{{if .OutArgs -}}
312({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
313{{- else -}}
314(o *tf.Operation)
315{{- end }} {
316	if scope.Err() != nil {
317		return
318	}
319	{{if .HasAttrs -}}
320	attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
321	{{if .OptionalAttrs -}}
322	for _, a := range optional {
323		a(attrs)
324	}
325	{{end -}}
326	{{end -}}
327	opspec := tf.OpSpec{
328		Type: {{printf "%q" .Op.Name}},
329		{{if .InArgs -}}
330		Input: []tf.Input{
331			{{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
332		},
333		{{- end}}
334		{{- if .HasAttrs}}
335		Attrs: attrs,
336		{{- end}}
337	}
338	{{- if .OutArgs}}
339	{{- if .HasListOutput}}
340	op := scope.AddOperation(opspec)
341	if scope.Err() != nil {
342		return
343	}
344	var idx int
345	var err error
346	{{- range $i, $a := .OutArgs}}
347	{{- if $a.IsListArg}}
348	if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
349		scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
350		return
351	}
352	{{- else }}
353	{{Identifier .RenameTo}} = op.Output(idx)
354	{{- end }}{{- /* if IsListArg */}}
355	{{- end }}{{- /* range .OutArgs */}}
356	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
357	{{- else }}
358	op := scope.AddOperation(opspec)
359	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
360	{{- end }}{{- /* if .HasListOutput */}}
361	{{- else }}
362	return scope.AddOperation(opspec)
363	{{- end }}{{- /* if .OutArgs */}}
364}
365`))
366)
367
368type attrWrapper struct {
369	op  *odpb.OpDef_AttrDef
370	api *adpb.ApiDef_Attr
371}
372
373func (a *attrWrapper) Name() string              { return a.api.Name }
374func (a *attrWrapper) RenameTo() string          { return a.api.RenameTo }
375func (a *attrWrapper) Description() string       { return a.api.Description }
376func (a *attrWrapper) Type() string              { return a.op.Type }
377func (a *attrWrapper) IsListAttr() bool          { return isListAttr(a.op) }
378func (a *attrWrapper) HasMinimum() bool          { return a.op.HasMinimum }
379func (a *attrWrapper) Minimum() int64            { return a.op.Minimum }
380func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
381
382type argWrapper struct {
383	op  *odpb.OpDef_ArgDef
384	api *adpb.ApiDef_Arg
385}
386
387func (a *argWrapper) Name() string        { return a.api.Name }
388func (a *argWrapper) RenameTo() string    { return a.api.RenameTo }
389func (a *argWrapper) Description() string { return a.api.Description }
390func (a *argWrapper) IsListArg() bool     { return isListArg(a.op) }
391
392type tmplArgs struct {
393	Op     *odpb.OpDef
394	APIDef *adpb.ApiDef
395	// Op.Attr is split into two categories
396	// (1) Required: These must be specified by the client and are thus
397	//     included in the function signature.
398	// (2) Optional: These need not be specified (as they have default
399	//     values) and thus do not appear in the function signature.
400	RequiredAttrs []*attrWrapper
401	OptionalAttrs []*attrWrapper
402	InArgs        []*argWrapper
403	// Input arguments ordered based on arg_order field of ApiDef.
404	InArgsReordered []*argWrapper
405	OutArgs         []*argWrapper
406}
407
408func newTmplArgs(op *odpb.OpDef, apidef *adpb.ApiDef) (*tmplArgs, error) {
409	ret := tmplArgs{Op: op, APIDef: apidef}
410
411	// Setup InArgs field
412	for i, in := range op.InputArg {
413		argCombined := argWrapper{op: in, api: apidef.InArg[i]}
414		ret.InArgs = append(ret.InArgs, &argCombined)
415	}
416
417	// Setup OutArgs field
418	for i, out := range op.OutputArg {
419		argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
420		ret.OutArgs = append(ret.OutArgs, &argCombined)
421	}
422
423	// Setup InArgsReordered field
424	for _, argName := range apidef.ArgOrder {
425		// Find the argument in op.InputArg
426		argIndex := -1
427		for i, in := range op.InputArg {
428			if in.Name == argName {
429				argIndex = i
430				break
431			}
432		}
433		if argIndex == -1 {
434			return nil, fmt.Errorf(
435				"couldn't find argument %s in ApiDef for op %s",
436				argName, op.Name)
437		}
438		argCombined := argWrapper{
439			op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
440		ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
441	}
442
443	if len(op.Attr) == 0 {
444		return &ret, nil
445	}
446	// Attributes related to the InputArg's type are inferred automatically
447	// and are not exposed to the client.
448	inferred := make(map[string]bool)
449	for _, in := range op.InputArg {
450		switch {
451		case in.TypeAttr != "":
452			inferred[in.TypeAttr] = true
453		case in.TypeListAttr != "":
454			inferred[in.TypeListAttr] = true
455		}
456		if in.NumberAttr != "" {
457			inferred[in.NumberAttr] = true
458		}
459	}
460	for i, attr := range op.Attr {
461		if inferred[attr.Name] {
462			continue
463		}
464		attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
465		if attr.DefaultValue == nil {
466			ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
467		} else {
468			ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
469		}
470	}
471	return &ret, nil
472}
473
474func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
475func (a *tmplArgs) DescribeArguments() bool {
476	for _, arg := range a.InArgs {
477		if arg.Description() != "" {
478			return true
479		}
480	}
481	for _, attr := range a.RequiredAttrs {
482		if attr.Description() != "" {
483			return true
484		}
485	}
486	return false
487
488}
489func (a *tmplArgs) DescribeOutputs() bool {
490	for _, arg := range a.OutArgs {
491		if arg.Description() != "" {
492			return true
493		}
494	}
495	return false
496}
497func (a *tmplArgs) HasListOutput() bool {
498	for _, arg := range a.OutArgs {
499		if arg.IsListArg() {
500			return true
501		}
502	}
503	return false
504}
505
// makeComment prepares possibly multi-line text for use after a leading "// "
// in generated code: every newline is followed by "// ", so each subsequent
// line continues the Go comment.
func makeComment(lines string) string {
	return strings.ReplaceAll(lines, "\n", "\n// ")
}
509
510// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.)
511// to the corresponding type in Go.
512func goType(tfType string) (string, error) {
513	list, tfType := parseTFType(tfType)
514	var gotype string
515	switch tfType {
516	case "int":
517		gotype = "int64"
518	case "float":
519		gotype = "float32"
520	case "bool":
521		gotype = "bool"
522	case "type":
523		gotype = "tf.DataType"
524	case "shape":
525		gotype = "tf.Shape"
526	case "tensor":
527		gotype = "tf.Tensor"
528	case "string":
529		gotype = "string"
530	default:
531		return "", fmt.Errorf("%q is not a recognized DataType", tfType)
532	}
533	if list {
534		gotype = "[]" + gotype
535	}
536	return gotype, nil
537}
538
// camelCase converts a snake_case name such as "output_type" to CamelCase
// ("OutputType"). It upper-cases the first byte of each underscore-separated
// word and drops the underscores, so it assumes ASCII names.
//
// Empty words, which come from an empty input or from leading, trailing or
// repeated underscores, are skipped. Previously they were indexed with w[0]
// and caused a panic.
func camelCase(snakeCase string) string {
	var b strings.Builder
	b.Grow(len(snakeCase))
	for _, w := range strings.Split(snakeCase, "_") {
		if w == "" {
			continue
		}
		b.WriteString(strings.ToUpper(w[:1]))
		b.WriteString(w[1:])
	}
	return b.String()
}
546
547// identifier creates an identifier for s usable in the generated Go source
548// code.
549//
550// Avoids collisions with keywords and other identifiers used in the generated
551// code.
552func identifier(s string) string {
553	// Identifiers used in the generated code.
554	if s == "tf" || s == "scope" || s == "err" || s == "op" {
555		return s + "_"
556	}
557	for _, k := range keywords {
558		if s == k {
559			// Alternatively, make the first letter upper case.
560			return s + "_"
561		}
562	}
563	return s
564}
565
566func isListArg(argdef *odpb.OpDef_ArgDef) bool {
567	return argdef.TypeListAttr != "" || argdef.NumberAttr != ""
568}
569
570func isListAttr(attrdef *odpb.OpDef_AttrDef) bool {
571	list, _ := parseTFType(attrdef.Type)
572	return list
573}
574
575func marshalProtoMessage(m proto.Message) string {
576	// Marshal proto message to string.
577	o := prototext.MarshalOptions{Multiline: false}
578	x := o.Format(m)
579
580	// Remove superfluous whitespace, if present.
581	//
582	// Go protobuf output is purposefully unstable, randomly adding
583	// whitespace.  See github.com/golang/protobuf/issues/1121
584	x = strings.ReplaceAll(x, "  ", " ")
585
586	// Remove the prefix of the string up to the first colon.
587	//
588	// This is useful when 's' corresponds to a "oneof" protocol buffer
589	// message. For example, consider the protocol buffer message:
590	//   oneof value { bool b = 1;  int64 i = 2; }
591	// proto.CompactTextString) will print "b:true", or "i:7" etc. The
592	// following strips out the leading "b:" or "i:".
593	y := strings.SplitN(x, ":", 2)
594	if len(y) < 2 {
595		return x
596	}
597	return y[1]
598}
599
// parseTFType splits a TensorFlow attribute type string. For "list(T)" it
// returns (true, "T"); for any other string, including a malformed
// "list(T" with no closing parenthesis, it returns (false, tfType).
func parseTFType(tfType string) (list bool, typ string) {
	const prefix, suffix = "list(", ")"
	if !strings.HasPrefix(tfType, prefix) || !strings.HasSuffix(tfType, suffix) {
		return false, tfType
	}
	// Both affixes are present and cannot overlap, so this slice is in range.
	return true, tfType[len(prefix) : len(tfType)-len(suffix)]
}
610