1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17// Package internal generates Go source code with functions for TensorFlow operations. 18// 19// The basic outline of the generated API is as follows: 20// 21// - One function for each TensorFlow operation 22// - The arguments to the function are the inputs and required attributes of the operation 23// - The function returns the outputs 24// - A function is also generated for each optional attribute of the operation. 25// 26// There is a possibility that there are name collisions between the functions 27// generated for ops and the functions generated for optional attributes. For 28// now, we ignore those, but will need to revisit if a collision is actually 29// encountered. 30package internal 31 32/* 33#include <stdlib.h> 34 35#include "tensorflow/c/c_api.h" 36*/ 37import "C" 38 39import ( 40 "fmt" 41 "io" 42 "io/ioutil" 43 "path" 44 "reflect" 45 "sort" 46 "strings" 47 "text/template" 48 "unsafe" 49 50 "google.golang.org/protobuf/encoding/prototext" 51 "google.golang.org/protobuf/proto" 52 53 adpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto" 54 odpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto" 55) 56 57// GenerateFunctionsForRegisteredOps writes a Go source code file to w 58// containing functions for each TensorFlow operation registered in the address 59// space of the calling process. 60// apidefDirs should be a contain of directories containing api_def_*.pbtxt 61// files to load. 62func GenerateFunctionsForRegisteredOps( 63 w io.Writer, apidefDirs []string) error { 64 ops, apimap, err := registeredOps() 65 if err != nil { 66 return err 67 } 68 for _, dir := range apidefDirs { 69 if err = updateAPIDefs(apimap, dir); err != nil { 70 return err 71 } 72 } 73 return generateFunctionsForOps(w, ops, apimap) 74} 75 76func registeredOps() (*odpb.OpList, *apiDefMap, error) { 77 buf := C.TF_GetAllOpList() 78 defer C.TF_DeleteBuffer(buf) 79 var ( 80 list = new(odpb.OpList) 81 size = int(buf.length) 82 // A []byte backed by C memory. 83 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 84 data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size] 85 err = proto.Unmarshal(data, list) 86 ) 87 if err != nil { 88 return nil, nil, err 89 } 90 // Sort ops by name 91 sort.Slice(list.Op, func(i, j int) bool { 92 return list.Op[i].Name < list.Op[j].Name 93 }) 94 apimap, err := newAPIDefMap(list) 95 return list, apimap, err 96} 97 98func updateAPIDefs(m *apiDefMap, dir string) error { 99 files, err := ioutil.ReadDir(dir) 100 if err != nil { 101 return err 102 } 103 for _, file := range files { 104 if file.IsDir() || !strings.HasSuffix(file.Name(), ".pbtxt") { 105 continue 106 } 107 data, err := ioutil.ReadFile(path.Join(dir, file.Name())) 108 if err != nil { 109 return fmt.Errorf("failed to read %q: %v", file.Name(), err) 110 } 111 if err = m.Put(string(data)); err != nil { 112 return fmt.Errorf("failed to process %q: %v", file.Name(), err) 113 } 114 } 115 return nil 116} 117 118func generateFunctionsForOps(w io.Writer, ops *odpb.OpList, apimap *apiDefMap) error { 119 thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath() 120 if err := tmplHeader.Execute(w, thisPackage); err != nil { 121 return err 122 } 123 denylist := map[string]bool{ 124 "Const": true, 125 "PyFunc": true, 126 "PyFuncStateless": true, 127 } 128 for _, op := range ops.Op { 129 if denylist[op.Name] { 130 continue 131 } 132 apidef, err := apimap.Get(op.Name) 133 if err != nil { 134 return err 135 } 136 if err := generateFunctionForOp(w, op, apidef); err != nil { 137 return err 138 } 139 } 140 return nil 141} 142 143func generateFunctionForOp(w io.Writer, op *odpb.OpDef, apidef *adpb.ApiDef) error { 144 if strings.HasPrefix(op.Name, "_") { // Internal operation 145 return nil 146 } 147 // Ignore operations where the Go types corresponding to the TensorFlow 148 // type haven't been worked out (such as "func"s). 149 for _, a := range op.Attr { 150 if _, err := goType(a.Type); err != nil { 151 return nil 152 } 153 } 154 // Also, haven't figured out reference types yet, so ignore those too. 155 for _, a := range op.InputArg { 156 if a.IsRef { 157 return nil 158 } 159 } 160 for _, a := range op.OutputArg { 161 if a.IsRef { 162 return nil 163 } 164 } 165 if apidef.Summary == "" { 166 // Undocumented operation, perhaps a sign of not being ready to 167 // export. 168 return nil 169 } 170 tmplArgs, err := newTmplArgs(op, apidef) 171 if err != nil { 172 return err 173 } 174 return tmplOp.Execute(w, tmplArgs) 175} 176 177var ( 178 // Go keywords that cannot be used as identifiers. 179 // From https://golang.org/ref/spec#Keywords 180 keywords = []string{ 181 "break", "default", "func", "interface", "select", "case", 182 "defer", "go", "map", "struct", "chan", "else", "goto", 183 "package", "switch", "const", "fallthrough", "if", "range", 184 "type", "continue", "for", "import", "return", "var", 185 } 186 187 tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT 188// This file was machine generated by {{.}} 189// 190// WARNING: This generation of wrapper function for TensorFlow ops is in an 191// experimental state. The generated API can change without notice. 192 193package op 194 195import tf "github.com/tensorflow/tensorflow/tensorflow/go" 196 197// optionalAttr is an intentionally un-exported type to hide 198// details of how optional attributes to operations are implemented. 199type optionalAttr map[string]interface{} 200 201func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) { 202 size, err := op.OutputListSize(output) 203 if err != nil { 204 return nil, start, err 205 } 206 list := make([]tf.Output, size) 207 for i := 0; i < size; i++ { 208 list[i] = op.Output(start + i) 209 } 210 return list, start + size, nil 211} 212`)) 213 214 tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{ 215 "MakeComment": makeComment, 216 "GoType": goType, 217 "CamelCase": camelCase, 218 "Identifier": identifier, 219 "IsListArg": isListArg, 220 "IsListAttr": isListAttr, 221 "MarshalProtoMessage": marshalProtoMessage, 222 }).Parse(` 223{{if .OptionalAttrs -}} 224{{/* Type for specifying all optional attributes. */ -}} 225// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}. 226type {{.Op.Name}}Attr func(optionalAttr) 227 228{{range .OptionalAttrs}} 229// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value. 230{{- if .Description}} 231// 232// value: {{MakeComment .Description}} 233{{- end}} 234// If not specified, defaults to {{MarshalProtoMessage .DefaultValue}} 235{{- if .HasMinimum}} 236// 237// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}} 238{{- end}} 239func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr { 240 return func(m optionalAttr) { 241 m[{{printf "%q" .Name}}] = value 242 } 243} 244{{end}} 245{{end}} 246 247{{- /* Create a godoc friendly comment. */ -}} 248 249// {{MakeComment .APIDef.Summary}} 250 251{{- with .Op.Deprecation}} 252// 253// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}} 254{{- end -}} 255 256{{- with .APIDef.Description}} 257// 258// {{MakeComment .}} 259{{- end -}} 260 261{{- if .DescribeArguments}} 262// 263// Arguments: 264{{- range .InArgsReordered}} 265// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 266{{- end -}} 267{{- range .RequiredAttrs}} 268// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 269{{- end -}} 270{{- end -}} 271 272{{- if (not .Op.OutputArg) }} 273// 274// Returns the created operation. 275{{- else }} 276{{- if .DescribeOutputs}} 277// 278{{- if eq (len .OutArgs) 1 }} 279// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}} 280{{- else }} 281// Returns: 282{{- range .OutArgs}} 283// {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}} 284{{- end -}} 285{{- end -}} 286{{- end -}} 287{{- end -}} 288{{- /* 289 290 The function signature. 291 Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang 292*/}} 293func {{.Op.Name}} 294 295{{- /* 296 Fill in input arguments: 297 (1) The Scope 298 (2) All input arguments (which may be either []tf.Output or tf.Output) 299 (3) All required attributes 300 (4) Variadic list of optional attributes 301*/ -}} 302 303(scope *Scope 304{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}} 305{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}} 306{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} 307) 308 309{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}} 310 311{{if .OutArgs -}} 312({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}) 313{{- else -}} 314(o *tf.Operation) 315{{- end }} { 316 if scope.Err() != nil { 317 return 318 } 319 {{if .HasAttrs -}} 320 attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}} 321 {{if .OptionalAttrs -}} 322 for _, a := range optional { 323 a(attrs) 324 } 325 {{end -}} 326 {{end -}} 327 opspec := tf.OpSpec{ 328 Type: {{printf "%q" .Op.Name}}, 329 {{if .InArgs -}} 330 Input: []tf.Input{ 331 {{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}} 332 }, 333 {{- end}} 334 {{- if .HasAttrs}} 335 Attrs: attrs, 336 {{- end}} 337 } 338 {{- if .OutArgs}} 339 {{- if .HasListOutput}} 340 op := scope.AddOperation(opspec) 341 if scope.Err() != nil { 342 return 343 } 344 var idx int 345 var err error 346 {{- range $i, $a := .OutArgs}} 347 {{- if $a.IsListArg}} 348 if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { 349 scope.UpdateErr({{printf "%q" $.Op.Name}}, err) 350 return 351 } 352 {{- else }} 353 {{Identifier .RenameTo}} = op.Output(idx) 354 {{- end }}{{- /* if IsListArg */}} 355 {{- end }}{{- /* range .OutArgs */}} 356 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}} 357 {{- else }} 358 op := scope.AddOperation(opspec) 359 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} 360 {{- end }}{{- /* if .HasListOutput */}} 361 {{- else }} 362 return scope.AddOperation(opspec) 363 {{- end }}{{- /* if .OutArgs */}} 364} 365`)) 366) 367 368type attrWrapper struct { 369 op *odpb.OpDef_AttrDef 370 api *adpb.ApiDef_Attr 371} 372 373func (a *attrWrapper) Name() string { return a.api.Name } 374func (a *attrWrapper) RenameTo() string { return a.api.RenameTo } 375func (a *attrWrapper) Description() string { return a.api.Description } 376func (a *attrWrapper) Type() string { return a.op.Type } 377func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) } 378func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum } 379func (a *attrWrapper) Minimum() int64 { return a.op.Minimum } 380func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue } 381 382type argWrapper struct { 383 op *odpb.OpDef_ArgDef 384 api *adpb.ApiDef_Arg 385} 386 387func (a *argWrapper) Name() string { return a.api.Name } 388func (a *argWrapper) RenameTo() string { return a.api.RenameTo } 389func (a *argWrapper) Description() string { return a.api.Description } 390func (a *argWrapper) IsListArg() bool { return isListArg(a.op) } 391 392type tmplArgs struct { 393 Op *odpb.OpDef 394 APIDef *adpb.ApiDef 395 // Op.Attr is split into two categories 396 // (1) Required: These must be specified by the client and are thus 397 // included in the function signature. 398 // (2) Optional: These need not be specified (as they have default 399 // values) and thus do not appear in the function signature. 400 RequiredAttrs []*attrWrapper 401 OptionalAttrs []*attrWrapper 402 InArgs []*argWrapper 403 // Input arguments ordered based on arg_order field of ApiDef. 404 InArgsReordered []*argWrapper 405 OutArgs []*argWrapper 406} 407 408func newTmplArgs(op *odpb.OpDef, apidef *adpb.ApiDef) (*tmplArgs, error) { 409 ret := tmplArgs{Op: op, APIDef: apidef} 410 411 // Setup InArgs field 412 for i, in := range op.InputArg { 413 argCombined := argWrapper{op: in, api: apidef.InArg[i]} 414 ret.InArgs = append(ret.InArgs, &argCombined) 415 } 416 417 // Setup OutArgs field 418 for i, out := range op.OutputArg { 419 argCombined := argWrapper{op: out, api: apidef.OutArg[i]} 420 ret.OutArgs = append(ret.OutArgs, &argCombined) 421 } 422 423 // Setup InArgsReordered field 424 for _, argName := range apidef.ArgOrder { 425 // Find the argument in op.InputArg 426 argIndex := -1 427 for i, in := range op.InputArg { 428 if in.Name == argName { 429 argIndex = i 430 break 431 } 432 } 433 if argIndex == -1 { 434 return nil, fmt.Errorf( 435 "couldn't find argument %s in ApiDef for op %s", 436 argName, op.Name) 437 } 438 argCombined := argWrapper{ 439 op: op.InputArg[argIndex], api: apidef.InArg[argIndex]} 440 ret.InArgsReordered = append(ret.InArgsReordered, &argCombined) 441 } 442 443 if len(op.Attr) == 0 { 444 return &ret, nil 445 } 446 // Attributes related to the InputArg's type are inferred automatically 447 // and are not exposed to the client. 448 inferred := make(map[string]bool) 449 for _, in := range op.InputArg { 450 switch { 451 case in.TypeAttr != "": 452 inferred[in.TypeAttr] = true 453 case in.TypeListAttr != "": 454 inferred[in.TypeListAttr] = true 455 } 456 if in.NumberAttr != "" { 457 inferred[in.NumberAttr] = true 458 } 459 } 460 for i, attr := range op.Attr { 461 if inferred[attr.Name] { 462 continue 463 } 464 attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]} 465 if attr.DefaultValue == nil { 466 ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined) 467 } else { 468 ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined) 469 } 470 } 471 return &ret, nil 472} 473 474func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 } 475func (a *tmplArgs) DescribeArguments() bool { 476 for _, arg := range a.InArgs { 477 if arg.Description() != "" { 478 return true 479 } 480 } 481 for _, attr := range a.RequiredAttrs { 482 if attr.Description() != "" { 483 return true 484 } 485 } 486 return false 487 488} 489func (a *tmplArgs) DescribeOutputs() bool { 490 for _, arg := range a.OutArgs { 491 if arg.Description() != "" { 492 return true 493 } 494 } 495 return false 496} 497func (a *tmplArgs) HasListOutput() bool { 498 for _, arg := range a.OutArgs { 499 if arg.IsListArg() { 500 return true 501 } 502 } 503 return false 504} 505 506func makeComment(lines string) string { 507 return strings.Join(strings.SplitAfter(lines, "\n"), "// ") 508} 509 510// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.) 511// to the corresponding type in Go. 512func goType(tfType string) (string, error) { 513 list, tfType := parseTFType(tfType) 514 var gotype string 515 switch tfType { 516 case "int": 517 gotype = "int64" 518 case "float": 519 gotype = "float32" 520 case "bool": 521 gotype = "bool" 522 case "type": 523 gotype = "tf.DataType" 524 case "shape": 525 gotype = "tf.Shape" 526 case "tensor": 527 gotype = "tf.Tensor" 528 case "string": 529 gotype = "string" 530 default: 531 return "", fmt.Errorf("%q is not a recognized DataType", tfType) 532 } 533 if list { 534 gotype = "[]" + gotype 535 } 536 return gotype, nil 537} 538 539func camelCase(snakeCase string) string { 540 words := strings.Split(snakeCase, "_") 541 for i, w := range words { 542 words[i] = strings.ToUpper(string(w[0])) + w[1:] 543 } 544 return strings.Join(words, "") 545} 546 547// identifier creates an identifier for s usable in the generated Go source 548// code. 549// 550// Avoids collisions with keywords and other identifiers used in the generated 551// code. 552func identifier(s string) string { 553 // Identifiers used in the generated code. 554 if s == "tf" || s == "scope" || s == "err" || s == "op" { 555 return s + "_" 556 } 557 for _, k := range keywords { 558 if s == k { 559 // Alternatively, make the first letter upper case. 560 return s + "_" 561 } 562 } 563 return s 564} 565 566func isListArg(argdef *odpb.OpDef_ArgDef) bool { 567 return argdef.TypeListAttr != "" || argdef.NumberAttr != "" 568} 569 570func isListAttr(attrdef *odpb.OpDef_AttrDef) bool { 571 list, _ := parseTFType(attrdef.Type) 572 return list 573} 574 575func marshalProtoMessage(m proto.Message) string { 576 // Marshal proto message to string. 577 o := prototext.MarshalOptions{Multiline: false} 578 x := o.Format(m) 579 580 // Remove superfluous whitespace, if present. 581 // 582 // Go protobuf output is purposefully unstable, randomly adding 583 // whitespace. See github.com/golang/protobuf/issues/1121 584 x = strings.ReplaceAll(x, " ", " ") 585 586 // Remove the prefix of the string up to the first colon. 587 // 588 // This is useful when 's' corresponds to a "oneof" protocol buffer 589 // message. For example, consider the protocol buffer message: 590 // oneof value { bool b = 1; int64 i = 2; } 591 // proto.CompactTextString) will print "b:true", or "i:7" etc. The 592 // following strips out the leading "b:" or "i:". 593 y := strings.SplitN(x, ":", 2) 594 if len(y) < 2 { 595 return x 596 } 597 return y[1] 598} 599 600func parseTFType(tfType string) (list bool, typ string) { 601 const ( 602 listPrefix = "list(" 603 listSuffix = ")" 604 ) 605 if strings.HasPrefix(tfType, listPrefix) && strings.HasSuffix(tfType, listSuffix) { 606 return true, strings.TrimSuffix(strings.TrimPrefix(tfType, listPrefix), listSuffix) 607 } 608 return false, tfType 609} 610