• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17// Package internal generates Go source code with functions for TensorFlow operations.
18//
19// The basic outline of the generated API is as follows:
20//
21// - One function for each TensorFlow operation
22// - The arguments to the function are the inputs and required attributes of the operation
23// - The function returns the outputs
24// - A function is also generated for each optional attribute of the operation.
25//
26// There is a possibility that there are name collisions between the functions
27// generated for ops and the functions generated for optional attributes. For
28// now, we ignore those, but will need to revisit if a collision is actually
29// encountered.
30package internal
31
32/*
33#include <stdlib.h>
34
35#include "tensorflow/c/c_api.h"
36*/
37import "C"
38
39import (
40	"fmt"
41	"io"
42	"io/ioutil"
43	"path"
44	"reflect"
45	"strings"
46	"text/template"
47	"unsafe"
48
49	"github.com/golang/protobuf/proto"
50	adpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto"
51	odpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto"
52)
53
54// GenerateFunctionsForRegisteredOps writes a Go source code file to w
55// containing functions for each TensorFlow operation registered in the address
56// space of the calling process.
57// apidefDirs should be a contain of directories containing api_def_*.pbtxt
58// files to load.
59func GenerateFunctionsForRegisteredOps(
60	w io.Writer, apidefDirs []string) error {
61	ops, apimap, err := registeredOps()
62	if err != nil {
63		return err
64	}
65	for _, dir := range apidefDirs {
66		if err = updateAPIDefs(apimap, dir); err != nil {
67			return err
68		}
69	}
70	return generateFunctionsForOps(w, ops, apimap)
71}
72
73func registeredOps() (*odpb.OpList, *apiDefMap, error) {
74	buf := C.TF_GetAllOpList()
75	defer C.TF_DeleteBuffer(buf)
76	var (
77		list = new(odpb.OpList)
78		size = int(buf.length)
79		// A []byte backed by C memory.
80		// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
81		data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
82		err  = proto.Unmarshal(data, list)
83	)
84	if err != nil {
85		return nil, nil, err
86	}
87	apimap, err := newAPIDefMap(list)
88	return list, apimap, err
89}
90
91func updateAPIDefs(m *apiDefMap, dir string) error {
92	files, err := ioutil.ReadDir(dir)
93	if err != nil {
94		return err
95	}
96	for _, file := range files {
97		if file.IsDir() || !strings.HasSuffix(file.Name(), ".pbtxt") {
98			continue
99		}
100		data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
101		if err != nil {
102			return fmt.Errorf("failed to read %q: %v", file.Name(), err)
103		}
104		if err = m.Put(string(data)); err != nil {
105			return fmt.Errorf("failed to process %q: %v", file.Name(), err)
106		}
107	}
108	return nil
109}
110
111func generateFunctionsForOps(w io.Writer, ops *odpb.OpList, apimap *apiDefMap) error {
112	thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
113	if err := tmplHeader.Execute(w, thisPackage); err != nil {
114		return err
115	}
116	denylist := map[string]bool{
117		"Const":           true,
118		"PyFunc":          true,
119		"PyFuncStateless": true,
120	}
121	for _, op := range ops.Op {
122		if denylist[op.Name] {
123			continue
124		}
125		apidef, err := apimap.Get(op.Name)
126		if err != nil {
127			return err
128		}
129		if err := generateFunctionForOp(w, op, apidef); err != nil {
130			return err
131		}
132	}
133	return nil
134}
135
136func generateFunctionForOp(w io.Writer, op *odpb.OpDef, apidef *adpb.ApiDef) error {
137	if strings.HasPrefix(op.Name, "_") { // Internal operation
138		return nil
139	}
140	// Ignore operations where the Go types corresponding to the TensorFlow
141	// type haven't been worked out (such as "func"s).
142	for _, a := range op.Attr {
143		if _, err := goType(a.Type); err != nil {
144			return nil
145		}
146	}
147	// Also, haven't figured out reference types yet, so ignore those too.
148	for _, a := range op.InputArg {
149		if a.IsRef {
150			return nil
151		}
152	}
153	for _, a := range op.OutputArg {
154		if a.IsRef {
155			return nil
156		}
157	}
158	if apidef.Summary == "" {
159		// Undocumented operation, perhaps a sign of not being ready to
160		// export.
161		return nil
162	}
163	tmplArgs, err := newTmplArgs(op, apidef)
164	if err != nil {
165		return err
166	}
167	return tmplOp.Execute(w, tmplArgs)
168}
169
170var (
171	// Go keywords that cannot be used as identifiers.
172	// From https://golang.org/ref/spec#Keywords
173	keywords = []string{
174		"break", "default", "func", "interface", "select", "case",
175		"defer", "go", "map", "struct", "chan", "else", "goto",
176		"package", "switch", "const", "fallthrough", "if", "range",
177		"type", "continue", "for", "import", "return", "var",
178	}
179
180	tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT
181// This file was machine generated by {{.}}
182//
183// WARNING: This generation of wrapper function for TensorFlow ops is in an
184// experimental state. The generated API can change without notice.
185
186package op
187
188import tf "github.com/tensorflow/tensorflow/tensorflow/go"
189
190// optionalAttr is an intentionally un-exported type to hide
191// details of how optional attributes to operations are implemented.
192type optionalAttr map[string]interface{}
193
194func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) {
195	size, err := op.OutputListSize(output)
196	if err != nil {
197		return nil, start, err
198	}
199	list := make([]tf.Output, size)
200	for i := 0; i < size; i++ {
201		list[i] = op.Output(start + i)
202	}
203	return list, start + size, nil
204}
205`))
206
207	tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{
208		"MakeComment":       makeComment,
209		"GoType":            goType,
210		"CamelCase":         camelCase,
211		"Identifier":        identifier,
212		"IsListArg":         isListArg,
213		"IsListAttr":        isListAttr,
214		"StripLeadingColon": stripLeadingColon,
215	}).Parse(`
216{{if .OptionalAttrs -}}
217{{/* Type for specifying all optional attributes. */ -}}
218// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}.
219type {{.Op.Name}}Attr func(optionalAttr)
220
221{{range .OptionalAttrs}}
222// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
223{{- if .Description}}
224//
225// value: {{MakeComment .Description}}
226{{- end}}
227// If not specified, defaults to {{StripLeadingColon .DefaultValue}}
228{{- if .HasMinimum}}
229//
230// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
231{{- end}}
232func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
233	return func(m optionalAttr) {
234		m[{{printf "%q" .Name}}] = value
235	}
236}
237{{end}}
238{{end}}
239
240{{- /* Create a godoc friendly comment. */ -}}
241
242// {{MakeComment .APIDef.Summary}}
243
244{{- with .Op.Deprecation}}
245//
246// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
247{{- end -}}
248
249{{- with .APIDef.Description}}
250//
251// {{MakeComment .}}
252{{- end -}}
253
254{{- if .DescribeArguments}}
255//
256// Arguments:
257{{- range .InArgsReordered}}
258//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
259{{- end -}}
260{{- range .RequiredAttrs}}
261//	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
262{{- end -}}
263{{- end -}}
264
265{{- if (not .Op.OutputArg) }}
266//
267// Returns the created operation.
268{{- else }}
269{{- if .DescribeOutputs}}
270//
271{{- if eq (len .OutArgs) 1 }}
272// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
273{{- else }}
274// Returns:
275{{- range .OutArgs}}
276//	{{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
277{{- end -}}
278{{- end -}}
279{{- end -}}
280{{- end -}}
281{{- /*
282
283  The function signature.
284  Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang
285*/}}
286func {{.Op.Name}}
287
288{{- /*
289  Fill in input arguments:
290  (1) The Scope
291  (2) All input arguments (which may be either []tf.Output or tf.Output)
292  (3) All required attributes
293  (4) Variadic list of optional attributes
294*/ -}}
295
296(scope *Scope
297{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
298{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
299{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
300)
301
302{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}
303
304{{if .OutArgs -}}
305({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
306{{- else -}}
307(o *tf.Operation)
308{{- end }} {
309	if scope.Err() != nil {
310		return
311	}
312	{{if .HasAttrs -}}
313	attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
314	{{if .OptionalAttrs -}}
315	for _, a := range optional {
316		a(attrs)
317	}
318	{{end -}}
319	{{end -}}
320	opspec := tf.OpSpec{
321		Type: {{printf "%q" .Op.Name}},
322		{{if .InArgs -}}
323		Input: []tf.Input{
324			{{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
325		},
326		{{- end}}
327		{{- if .HasAttrs}}
328		Attrs: attrs,
329		{{- end}}
330	}
331	{{- if .OutArgs}}
332	{{- if .HasListOutput}}
333	op := scope.AddOperation(opspec)
334	if scope.Err() != nil {
335		return
336	}
337	var idx int
338	var err error
339	{{- range $i, $a := .OutArgs}}
340	{{- if $a.IsListArg}}
341	if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
342		scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
343		return
344	}
345	{{- else }}
346	{{Identifier .RenameTo}} = op.Output(idx)
347	{{- end }}{{- /* if IsListArg */}}
348	{{- end }}{{- /* range .OutArgs */}}
349	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
350	{{- else }}
351	op := scope.AddOperation(opspec)
352	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
353	{{- end }}{{- /* if .HasListOutput */}}
354	{{- else }}
355	return scope.AddOperation(opspec)
356	{{- end }}{{- /* if .OutArgs */}}
357}
358`))
359)
360
361type attrWrapper struct {
362	op  *odpb.OpDef_AttrDef
363	api *adpb.ApiDef_Attr
364}
365
366func (a *attrWrapper) Name() string              { return a.api.Name }
367func (a *attrWrapper) RenameTo() string          { return a.api.RenameTo }
368func (a *attrWrapper) Description() string       { return a.api.Description }
369func (a *attrWrapper) Type() string              { return a.op.Type }
370func (a *attrWrapper) IsListAttr() bool          { return isListAttr(a.op) }
371func (a *attrWrapper) HasMinimum() bool          { return a.op.HasMinimum }
372func (a *attrWrapper) Minimum() int64            { return a.op.Minimum }
373func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
374
375type argWrapper struct {
376	op  *odpb.OpDef_ArgDef
377	api *adpb.ApiDef_Arg
378}
379
380func (a *argWrapper) Name() string        { return a.api.Name }
381func (a *argWrapper) RenameTo() string    { return a.api.RenameTo }
382func (a *argWrapper) Description() string { return a.api.Description }
383func (a *argWrapper) IsListArg() bool     { return isListArg(a.op) }
384
385type tmplArgs struct {
386	Op     *odpb.OpDef
387	APIDef *adpb.ApiDef
388	// Op.Attr is split into two categories
389	// (1) Required: These must be specified by the client and are thus
390	//     included in the function signature.
391	// (2) Optional: These need not be specified (as they have default
392	//     values) and thus do not appear in the function signature.
393	RequiredAttrs []*attrWrapper
394	OptionalAttrs []*attrWrapper
395	InArgs        []*argWrapper
396	// Input arguments ordered based on arg_order field of ApiDef.
397	InArgsReordered []*argWrapper
398	OutArgs         []*argWrapper
399}
400
401func newTmplArgs(op *odpb.OpDef, apidef *adpb.ApiDef) (*tmplArgs, error) {
402	ret := tmplArgs{Op: op, APIDef: apidef}
403
404	// Setup InArgs field
405	for i, in := range op.InputArg {
406		argCombined := argWrapper{op: in, api: apidef.InArg[i]}
407		ret.InArgs = append(ret.InArgs, &argCombined)
408	}
409
410	// Setup OutArgs field
411	for i, out := range op.OutputArg {
412		argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
413		ret.OutArgs = append(ret.OutArgs, &argCombined)
414	}
415
416	// Setup InArgsReordered field
417	for _, argName := range apidef.ArgOrder {
418		// Find the argument in op.InputArg
419		argIndex := -1
420		for i, in := range op.InputArg {
421			if in.Name == argName {
422				argIndex = i
423				break
424			}
425		}
426		if argIndex == -1 {
427			return nil, fmt.Errorf(
428				"couldn't find argument %s in ApiDef for op %s",
429				argName, op.Name)
430		}
431		argCombined := argWrapper{
432			op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
433		ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
434	}
435
436	if len(op.Attr) == 0 {
437		return &ret, nil
438	}
439	// Attributes related to the InputArg's type are inferred automatically
440	// and are not exposed to the client.
441	inferred := make(map[string]bool)
442	for _, in := range op.InputArg {
443		switch {
444		case in.TypeAttr != "":
445			inferred[in.TypeAttr] = true
446		case in.TypeListAttr != "":
447			inferred[in.TypeListAttr] = true
448		}
449		if in.NumberAttr != "" {
450			inferred[in.NumberAttr] = true
451		}
452	}
453	for i, attr := range op.Attr {
454		if inferred[attr.Name] {
455			continue
456		}
457		attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
458		if attr.DefaultValue == nil {
459			ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
460		} else {
461			ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
462		}
463	}
464	return &ret, nil
465}
466
467func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
468func (a *tmplArgs) DescribeArguments() bool {
469	for _, arg := range a.InArgs {
470		if arg.Description() != "" {
471			return true
472		}
473	}
474	for _, attr := range a.RequiredAttrs {
475		if attr.Description() != "" {
476			return true
477		}
478	}
479	return false
480
481}
482func (a *tmplArgs) DescribeOutputs() bool {
483	for _, arg := range a.OutArgs {
484		if arg.Description() != "" {
485			return true
486		}
487	}
488	return false
489}
490func (a *tmplArgs) HasListOutput() bool {
491	for _, arg := range a.OutArgs {
492		if arg.IsListArg() {
493			return true
494		}
495	}
496	return false
497}
498
// makeComment prepares multi-line text for a "// " comment: every newline in
// lines is followed by "// " so that each continuation line is also commented.
// The caller supplies the "// " for the first line.
func makeComment(lines string) string {
	return strings.Replace(lines, "\n", "\n// ", -1)
}
502
503// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.)
504// to the corresponding type in Go.
505func goType(tfType string) (string, error) {
506	list, tfType := parseTFType(tfType)
507	var gotype string
508	switch tfType {
509	case "int":
510		gotype = "int64"
511	case "float":
512		gotype = "float32"
513	case "bool":
514		gotype = "bool"
515	case "type":
516		gotype = "tf.DataType"
517	case "shape":
518		gotype = "tf.Shape"
519	case "tensor":
520		gotype = "tf.Tensor"
521	case "string":
522		gotype = "string"
523	default:
524		return "", fmt.Errorf("%q is not a recognized DataType", tfType)
525	}
526	if list {
527		gotype = "[]" + gotype
528	}
529	return gotype, nil
530}
531
// camelCase converts a snake_case name to CamelCase by upper-casing the first
// character of each underscore-separated word and dropping the underscores,
// e.g. "data_format" -> "DataFormat". Names are assumed to be ASCII.
//
// Empty words, produced by an empty input or by leading, trailing or repeated
// underscores, are skipped rather than indexed, which would panic.
func camelCase(snakeCase string) string {
	words := strings.Split(snakeCase, "_")
	for i, w := range words {
		if w == "" {
			continue
		}
		words[i] = strings.ToUpper(string(w[0])) + w[1:]
	}
	return strings.Join(words, "")
}
539
540// identifier creates an identifier for s usable in the generated Go source
541// code.
542//
543// Avoids collisions with keywords and other identifiers used in the generated
544// code.
545func identifier(s string) string {
546	// Identifiers used in the generated code.
547	if s == "tf" || s == "scope" || s == "err" || s == "op" {
548		return s + "_"
549	}
550	for _, k := range keywords {
551		if s == k {
552			// Alternatively, make the first letter upper case.
553			return s + "_"
554		}
555	}
556	return s
557}
558
559func isListArg(argdef *odpb.OpDef_ArgDef) bool {
560	return argdef.TypeListAttr != "" || argdef.NumberAttr != ""
561}
562
563func isListAttr(attrdef *odpb.OpDef_AttrDef) bool {
564	list, _ := parseTFType(attrdef.Type)
565	return list
566}
567
568// stripLeadingColon removes the prefix of the string up to the first colon.
569//
570// This is useful when 's' corresponds to a "oneof" protocol buffer message.
571// For example, consider the protocol buffer message:
572//   oneof value { bool b = 1;  int64 i = 2; }
573// proto.CompactTextString) will print "b:true", or "i:7" etc. This function
574// strips out the leading "b:" or "i:".
575func stripLeadingColon(m proto.Message) string {
576	x := proto.CompactTextString(m)
577	y := strings.SplitN(x, ":", 2)
578	if len(y) < 2 {
579		return x
580	}
581	return y[1]
582}
583
// parseTFType splits a TensorFlow attribute type string. For "list(T)" it
// returns (true, "T"). Any other string is returned unchanged together with
// false.
func parseTFType(tfType string) (list bool, typ string) {
	const opening, closing = "list(", ")"
	inner := strings.TrimPrefix(tfType, opening)
	if inner == tfType || !strings.HasSuffix(inner, closing) {
		return false, tfType
	}
	return true, strings.TrimSuffix(inner, closing)
}
594