1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17// Package internal generates Go source code with functions for TensorFlow operations. 18// 19// The basic outline of the generated API is as follows: 20// 21// - One function for each TensorFlow operation 22// - The arguments to the function are the inputs and required attributes of the operation 23// - The function returns the outputs 24// - A function is also generated for each optional attribute of the operation. 25// 26// There is a possibility that there are name collisions between the functions 27// generated for ops and the functions generated for optional attributes. For 28// now, we ignore those, but will need to revisit if a collision is actually 29// encountered. 30package internal 31 32/* 33#include <stdlib.h> 34 35#include "tensorflow/c/c_api.h" 36*/ 37import "C" 38 39import ( 40 "fmt" 41 "io" 42 "io/ioutil" 43 "path" 44 "reflect" 45 "strings" 46 "text/template" 47 "unsafe" 48 49 "github.com/golang/protobuf/proto" 50 adpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto" 51 odpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto" 52) 53 54// GenerateFunctionsForRegisteredOps writes a Go source code file to w 55// containing functions for each TensorFlow operation registered in the address 56// space of the calling process. 57// apidefDirs should be a contain of directories containing api_def_*.pbtxt 58// files to load. 59func GenerateFunctionsForRegisteredOps( 60 w io.Writer, apidefDirs []string) error { 61 ops, apimap, err := registeredOps() 62 if err != nil { 63 return err 64 } 65 for _, dir := range apidefDirs { 66 if err = updateAPIDefs(apimap, dir); err != nil { 67 return err 68 } 69 } 70 return generateFunctionsForOps(w, ops, apimap) 71} 72 73func registeredOps() (*odpb.OpList, *apiDefMap, error) { 74 buf := C.TF_GetAllOpList() 75 defer C.TF_DeleteBuffer(buf) 76 var ( 77 list = new(odpb.OpList) 78 size = int(buf.length) 79 // A []byte backed by C memory. 80 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 81 data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size] 82 err = proto.Unmarshal(data, list) 83 ) 84 if err != nil { 85 return nil, nil, err 86 } 87 apimap, err := newAPIDefMap(list) 88 return list, apimap, err 89} 90 91func updateAPIDefs(m *apiDefMap, dir string) error { 92 files, err := ioutil.ReadDir(dir) 93 if err != nil { 94 return err 95 } 96 for _, file := range files { 97 if file.IsDir() || !strings.HasSuffix(file.Name(), ".pbtxt") { 98 continue 99 } 100 data, err := ioutil.ReadFile(path.Join(dir, file.Name())) 101 if err != nil { 102 return fmt.Errorf("failed to read %q: %v", file.Name(), err) 103 } 104 if err = m.Put(string(data)); err != nil { 105 return fmt.Errorf("failed to process %q: %v", file.Name(), err) 106 } 107 } 108 return nil 109} 110 111func generateFunctionsForOps(w io.Writer, ops *odpb.OpList, apimap *apiDefMap) error { 112 thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath() 113 if err := tmplHeader.Execute(w, thisPackage); err != nil { 114 return err 115 } 116 denylist := map[string]bool{ 117 "Const": true, 118 "PyFunc": true, 119 "PyFuncStateless": true, 120 } 121 for _, op := range ops.Op { 122 if denylist[op.Name] { 123 continue 124 } 125 apidef, err := apimap.Get(op.Name) 126 if err != nil { 127 return err 128 } 129 if err := generateFunctionForOp(w, op, apidef); err != nil { 130 return err 131 } 132 } 133 return nil 134} 135 136func generateFunctionForOp(w io.Writer, op *odpb.OpDef, apidef *adpb.ApiDef) error { 137 if strings.HasPrefix(op.Name, "_") { // Internal operation 138 return nil 139 } 140 // Ignore operations where the Go types corresponding to the TensorFlow 141 // type haven't been worked out (such as "func"s). 142 for _, a := range op.Attr { 143 if _, err := goType(a.Type); err != nil { 144 return nil 145 } 146 } 147 // Also, haven't figured out reference types yet, so ignore those too. 148 for _, a := range op.InputArg { 149 if a.IsRef { 150 return nil 151 } 152 } 153 for _, a := range op.OutputArg { 154 if a.IsRef { 155 return nil 156 } 157 } 158 if apidef.Summary == "" { 159 // Undocumented operation, perhaps a sign of not being ready to 160 // export. 161 return nil 162 } 163 tmplArgs, err := newTmplArgs(op, apidef) 164 if err != nil { 165 return err 166 } 167 return tmplOp.Execute(w, tmplArgs) 168} 169 170var ( 171 // Go keywords that cannot be used as identifiers. 172 // From https://golang.org/ref/spec#Keywords 173 keywords = []string{ 174 "break", "default", "func", "interface", "select", "case", 175 "defer", "go", "map", "struct", "chan", "else", "goto", 176 "package", "switch", "const", "fallthrough", "if", "range", 177 "type", "continue", "for", "import", "return", "var", 178 } 179 180 tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT 181// This file was machine generated by {{.}} 182// 183// WARNING: This generation of wrapper function for TensorFlow ops is in an 184// experimental state. The generated API can change without notice. 185 186package op 187 188import tf "github.com/tensorflow/tensorflow/tensorflow/go" 189 190// optionalAttr is an intentionally un-exported type to hide 191// details of how optional attributes to operations are implemented. 192type optionalAttr map[string]interface{} 193 194func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) { 195 size, err := op.OutputListSize(output) 196 if err != nil { 197 return nil, start, err 198 } 199 list := make([]tf.Output, size) 200 for i := 0; i < size; i++ { 201 list[i] = op.Output(start + i) 202 } 203 return list, start + size, nil 204} 205`)) 206 207 tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{ 208 "MakeComment": makeComment, 209 "GoType": goType, 210 "CamelCase": camelCase, 211 "Identifier": identifier, 212 "IsListArg": isListArg, 213 "IsListAttr": isListAttr, 214 "StripLeadingColon": stripLeadingColon, 215 }).Parse(` 216{{if .OptionalAttrs -}} 217{{/* Type for specifying all optional attributes. */ -}} 218// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}. 219type {{.Op.Name}}Attr func(optionalAttr) 220 221{{range .OptionalAttrs}} 222// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value. 223{{- if .Description}} 224// 225// value: {{MakeComment .Description}} 226{{- end}} 227// If not specified, defaults to {{StripLeadingColon .DefaultValue}} 228{{- if .HasMinimum}} 229// 230// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}} 231{{- end}} 232func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr { 233 return func(m optionalAttr) { 234 m[{{printf "%q" .Name}}] = value 235 } 236} 237{{end}} 238{{end}} 239 240{{- /* Create a godoc friendly comment. */ -}} 241 242// {{MakeComment .APIDef.Summary}} 243 244{{- with .Op.Deprecation}} 245// 246// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}} 247{{- end -}} 248 249{{- with .APIDef.Description}} 250// 251// {{MakeComment .}} 252{{- end -}} 253 254{{- if .DescribeArguments}} 255// 256// Arguments: 257{{- range .InArgsReordered}} 258// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 259{{- end -}} 260{{- range .RequiredAttrs}} 261// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 262{{- end -}} 263{{- end -}} 264 265{{- if (not .Op.OutputArg) }} 266// 267// Returns the created operation. 268{{- else }} 269{{- if .DescribeOutputs}} 270// 271{{- if eq (len .OutArgs) 1 }} 272// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}} 273{{- else }} 274// Returns: 275{{- range .OutArgs}} 276// {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}} 277{{- end -}} 278{{- end -}} 279{{- end -}} 280{{- end -}} 281{{- /* 282 283 The function signature. 284 Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang 285*/}} 286func {{.Op.Name}} 287 288{{- /* 289 Fill in input arguments: 290 (1) The Scope 291 (2) All input arguments (which may be either []tf.Output or tf.Output) 292 (3) All required attributes 293 (4) Variadic list of optional attributes 294*/ -}} 295 296(scope *Scope 297{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}} 298{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}} 299{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} 300) 301 302{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}} 303 304{{if .OutArgs -}} 305({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}) 306{{- else -}} 307(o *tf.Operation) 308{{- end }} { 309 if scope.Err() != nil { 310 return 311 } 312 {{if .HasAttrs -}} 313 attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}} 314 {{if .OptionalAttrs -}} 315 for _, a := range optional { 316 a(attrs) 317 } 318 {{end -}} 319 {{end -}} 320 opspec := tf.OpSpec{ 321 Type: {{printf "%q" .Op.Name}}, 322 {{if .InArgs -}} 323 Input: []tf.Input{ 324 {{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}} 325 }, 326 {{- end}} 327 {{- if .HasAttrs}} 328 Attrs: attrs, 329 {{- end}} 330 } 331 {{- if .OutArgs}} 332 {{- if .HasListOutput}} 333 op := scope.AddOperation(opspec) 334 if scope.Err() != nil { 335 return 336 } 337 var idx int 338 var err error 339 {{- range $i, $a := .OutArgs}} 340 {{- if $a.IsListArg}} 341 if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { 342 scope.UpdateErr({{printf "%q" $.Op.Name}}, err) 343 return 344 } 345 {{- else }} 346 {{Identifier .RenameTo}} = op.Output(idx) 347 {{- end }}{{- /* if IsListArg */}} 348 {{- end }}{{- /* range .OutArgs */}} 349 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}} 350 {{- else }} 351 op := scope.AddOperation(opspec) 352 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} 353 {{- end }}{{- /* if .HasListOutput */}} 354 {{- else }} 355 return scope.AddOperation(opspec) 356 {{- end }}{{- /* if .OutArgs */}} 357} 358`)) 359) 360 361type attrWrapper struct { 362 op *odpb.OpDef_AttrDef 363 api *adpb.ApiDef_Attr 364} 365 366func (a *attrWrapper) Name() string { return a.api.Name } 367func (a *attrWrapper) RenameTo() string { return a.api.RenameTo } 368func (a *attrWrapper) Description() string { return a.api.Description } 369func (a *attrWrapper) Type() string { return a.op.Type } 370func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) } 371func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum } 372func (a *attrWrapper) Minimum() int64 { return a.op.Minimum } 373func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue } 374 375type argWrapper struct { 376 op *odpb.OpDef_ArgDef 377 api *adpb.ApiDef_Arg 378} 379 380func (a *argWrapper) Name() string { return a.api.Name } 381func (a *argWrapper) RenameTo() string { return a.api.RenameTo } 382func (a *argWrapper) Description() string { return a.api.Description } 383func (a *argWrapper) IsListArg() bool { return isListArg(a.op) } 384 385type tmplArgs struct { 386 Op *odpb.OpDef 387 APIDef *adpb.ApiDef 388 // Op.Attr is split into two categories 389 // (1) Required: These must be specified by the client and are thus 390 // included in the function signature. 391 // (2) Optional: These need not be specified (as they have default 392 // values) and thus do not appear in the function signature. 393 RequiredAttrs []*attrWrapper 394 OptionalAttrs []*attrWrapper 395 InArgs []*argWrapper 396 // Input arguments ordered based on arg_order field of ApiDef. 397 InArgsReordered []*argWrapper 398 OutArgs []*argWrapper 399} 400 401func newTmplArgs(op *odpb.OpDef, apidef *adpb.ApiDef) (*tmplArgs, error) { 402 ret := tmplArgs{Op: op, APIDef: apidef} 403 404 // Setup InArgs field 405 for i, in := range op.InputArg { 406 argCombined := argWrapper{op: in, api: apidef.InArg[i]} 407 ret.InArgs = append(ret.InArgs, &argCombined) 408 } 409 410 // Setup OutArgs field 411 for i, out := range op.OutputArg { 412 argCombined := argWrapper{op: out, api: apidef.OutArg[i]} 413 ret.OutArgs = append(ret.OutArgs, &argCombined) 414 } 415 416 // Setup InArgsReordered field 417 for _, argName := range apidef.ArgOrder { 418 // Find the argument in op.InputArg 419 argIndex := -1 420 for i, in := range op.InputArg { 421 if in.Name == argName { 422 argIndex = i 423 break 424 } 425 } 426 if argIndex == -1 { 427 return nil, fmt.Errorf( 428 "couldn't find argument %s in ApiDef for op %s", 429 argName, op.Name) 430 } 431 argCombined := argWrapper{ 432 op: op.InputArg[argIndex], api: apidef.InArg[argIndex]} 433 ret.InArgsReordered = append(ret.InArgsReordered, &argCombined) 434 } 435 436 if len(op.Attr) == 0 { 437 return &ret, nil 438 } 439 // Attributes related to the InputArg's type are inferred automatically 440 // and are not exposed to the client. 441 inferred := make(map[string]bool) 442 for _, in := range op.InputArg { 443 switch { 444 case in.TypeAttr != "": 445 inferred[in.TypeAttr] = true 446 case in.TypeListAttr != "": 447 inferred[in.TypeListAttr] = true 448 } 449 if in.NumberAttr != "" { 450 inferred[in.NumberAttr] = true 451 } 452 } 453 for i, attr := range op.Attr { 454 if inferred[attr.Name] { 455 continue 456 } 457 attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]} 458 if attr.DefaultValue == nil { 459 ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined) 460 } else { 461 ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined) 462 } 463 } 464 return &ret, nil 465} 466 467func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 } 468func (a *tmplArgs) DescribeArguments() bool { 469 for _, arg := range a.InArgs { 470 if arg.Description() != "" { 471 return true 472 } 473 } 474 for _, attr := range a.RequiredAttrs { 475 if attr.Description() != "" { 476 return true 477 } 478 } 479 return false 480 481} 482func (a *tmplArgs) DescribeOutputs() bool { 483 for _, arg := range a.OutArgs { 484 if arg.Description() != "" { 485 return true 486 } 487 } 488 return false 489} 490func (a *tmplArgs) HasListOutput() bool { 491 for _, arg := range a.OutArgs { 492 if arg.IsListArg() { 493 return true 494 } 495 } 496 return false 497} 498 499func makeComment(lines string) string { 500 return strings.Join(strings.SplitAfter(lines, "\n"), "// ") 501} 502 503// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.) 504// to the corresponding type in Go. 505func goType(tfType string) (string, error) { 506 list, tfType := parseTFType(tfType) 507 var gotype string 508 switch tfType { 509 case "int": 510 gotype = "int64" 511 case "float": 512 gotype = "float32" 513 case "bool": 514 gotype = "bool" 515 case "type": 516 gotype = "tf.DataType" 517 case "shape": 518 gotype = "tf.Shape" 519 case "tensor": 520 gotype = "tf.Tensor" 521 case "string": 522 gotype = "string" 523 default: 524 return "", fmt.Errorf("%q is not a recognized DataType", tfType) 525 } 526 if list { 527 gotype = "[]" + gotype 528 } 529 return gotype, nil 530} 531 532func camelCase(snakeCase string) string { 533 words := strings.Split(snakeCase, "_") 534 for i, w := range words { 535 words[i] = strings.ToUpper(string(w[0])) + w[1:] 536 } 537 return strings.Join(words, "") 538} 539 540// identifier creates an identifier for s usable in the generated Go source 541// code. 542// 543// Avoids collisions with keywords and other identifiers used in the generated 544// code. 545func identifier(s string) string { 546 // Identifiers used in the generated code. 547 if s == "tf" || s == "scope" || s == "err" || s == "op" { 548 return s + "_" 549 } 550 for _, k := range keywords { 551 if s == k { 552 // Alternatively, make the first letter upper case. 553 return s + "_" 554 } 555 } 556 return s 557} 558 559func isListArg(argdef *odpb.OpDef_ArgDef) bool { 560 return argdef.TypeListAttr != "" || argdef.NumberAttr != "" 561} 562 563func isListAttr(attrdef *odpb.OpDef_AttrDef) bool { 564 list, _ := parseTFType(attrdef.Type) 565 return list 566} 567 568// stripLeadingColon removes the prefix of the string up to the first colon. 569// 570// This is useful when 's' corresponds to a "oneof" protocol buffer message. 571// For example, consider the protocol buffer message: 572// oneof value { bool b = 1; int64 i = 2; } 573// proto.CompactTextString) will print "b:true", or "i:7" etc. This function 574// strips out the leading "b:" or "i:". 575func stripLeadingColon(m proto.Message) string { 576 x := proto.CompactTextString(m) 577 y := strings.SplitN(x, ":", 2) 578 if len(y) < 2 { 579 return x 580 } 581 return y[1] 582} 583 584func parseTFType(tfType string) (list bool, typ string) { 585 const ( 586 listPrefix = "list(" 587 listSuffix = ")" 588 ) 589 if strings.HasPrefix(tfType, listPrefix) && strings.HasSuffix(tfType, listSuffix) { 590 return true, strings.TrimSuffix(strings.TrimPrefix(tfType, listPrefix), listSuffix) 591 } 592 return false, tfType 593} 594