1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17// Package internal generates Go source code with functions for TensorFlow operations. 18// 19// The basic outline of the generated API is as follows: 20// 21// - One function for each TensorFlow operation 22// - The arguments to the function are the inputs and required attributes of the operation 23// - The function returns the outputs 24// - A function is also generated for each optional attribute of the operation. 25// 26// There is a possibility that there are name collisions between the functions 27// generated for ops and the functions generated for optional attributes. For 28// now, we ignore those, but will need to revisit if a collision is actually 29// encountered. 30package internal 31 32/* 33#include <stdlib.h> 34 35#include "tensorflow/c/c_api.h" 36*/ 37import "C" 38 39import ( 40 "fmt" 41 "io" 42 "io/ioutil" 43 "path" 44 "reflect" 45 "strings" 46 "text/template" 47 "unsafe" 48 49 "google.golang.org/protobuf/encoding/prototext" 50 "google.golang.org/protobuf/proto" 51 adpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto" 52 odpb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto" 53) 54 55// GenerateFunctionsForRegisteredOps writes a Go source code file to w 56// containing functions for each TensorFlow operation registered in the address 57// space of the calling process. 58// apidefDirs should be a contain of directories containing api_def_*.pbtxt 59// files to load. 60func GenerateFunctionsForRegisteredOps( 61 w io.Writer, apidefDirs []string) error { 62 ops, apimap, err := registeredOps() 63 if err != nil { 64 return err 65 } 66 for _, dir := range apidefDirs { 67 if err = updateAPIDefs(apimap, dir); err != nil { 68 return err 69 } 70 } 71 return generateFunctionsForOps(w, ops, apimap) 72} 73 74func registeredOps() (*odpb.OpList, *apiDefMap, error) { 75 buf := C.TF_GetAllOpList() 76 defer C.TF_DeleteBuffer(buf) 77 var ( 78 list = new(odpb.OpList) 79 size = int(buf.length) 80 // A []byte backed by C memory. 81 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 82 data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size] 83 err = proto.Unmarshal(data, list) 84 ) 85 if err != nil { 86 return nil, nil, err 87 } 88 apimap, err := newAPIDefMap(list) 89 return list, apimap, err 90} 91 92func updateAPIDefs(m *apiDefMap, dir string) error { 93 files, err := ioutil.ReadDir(dir) 94 if err != nil { 95 return err 96 } 97 for _, file := range files { 98 if file.IsDir() || !strings.HasSuffix(file.Name(), ".pbtxt") { 99 continue 100 } 101 data, err := ioutil.ReadFile(path.Join(dir, file.Name())) 102 if err != nil { 103 return fmt.Errorf("failed to read %q: %v", file.Name(), err) 104 } 105 if err = m.Put(string(data)); err != nil { 106 return fmt.Errorf("failed to process %q: %v", file.Name(), err) 107 } 108 } 109 return nil 110} 111 112func generateFunctionsForOps(w io.Writer, ops *odpb.OpList, apimap *apiDefMap) error { 113 thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath() 114 if err := tmplHeader.Execute(w, thisPackage); err != nil { 115 return err 116 } 117 denylist := map[string]bool{ 118 "Const": true, 119 "PyFunc": true, 120 "PyFuncStateless": true, 121 } 122 for _, op := range ops.Op { 123 if denylist[op.Name] { 124 continue 125 } 126 apidef, err := apimap.Get(op.Name) 127 if err != nil { 128 return err 129 } 130 if err := generateFunctionForOp(w, op, apidef); err != nil { 131 return err 132 } 133 } 134 return nil 135} 136 137func generateFunctionForOp(w io.Writer, op *odpb.OpDef, apidef *adpb.ApiDef) error { 138 if strings.HasPrefix(op.Name, "_") { // Internal operation 139 return nil 140 } 141 // Ignore operations where the Go types corresponding to the TensorFlow 142 // type haven't been worked out (such as "func"s). 143 for _, a := range op.Attr { 144 if _, err := goType(a.Type); err != nil { 145 return nil 146 } 147 } 148 // Also, haven't figured out reference types yet, so ignore those too. 149 for _, a := range op.InputArg { 150 if a.IsRef { 151 return nil 152 } 153 } 154 for _, a := range op.OutputArg { 155 if a.IsRef { 156 return nil 157 } 158 } 159 if apidef.Summary == "" { 160 // Undocumented operation, perhaps a sign of not being ready to 161 // export. 162 return nil 163 } 164 tmplArgs, err := newTmplArgs(op, apidef) 165 if err != nil { 166 return err 167 } 168 return tmplOp.Execute(w, tmplArgs) 169} 170 171var ( 172 // Go keywords that cannot be used as identifiers. 173 // From https://golang.org/ref/spec#Keywords 174 keywords = []string{ 175 "break", "default", "func", "interface", "select", "case", 176 "defer", "go", "map", "struct", "chan", "else", "goto", 177 "package", "switch", "const", "fallthrough", "if", "range", 178 "type", "continue", "for", "import", "return", "var", 179 } 180 181 tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT 182// This file was machine generated by {{.}} 183// 184// WARNING: This generation of wrapper function for TensorFlow ops is in an 185// experimental state. The generated API can change without notice. 186 187package op 188 189import tf "github.com/tensorflow/tensorflow/tensorflow/go" 190 191// optionalAttr is an intentionally un-exported type to hide 192// details of how optional attributes to operations are implemented. 193type optionalAttr map[string]interface{} 194 195func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) { 196 size, err := op.OutputListSize(output) 197 if err != nil { 198 return nil, start, err 199 } 200 list := make([]tf.Output, size) 201 for i := 0; i < size; i++ { 202 list[i] = op.Output(start + i) 203 } 204 return list, start + size, nil 205} 206`)) 207 208 tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{ 209 "MakeComment": makeComment, 210 "GoType": goType, 211 "CamelCase": camelCase, 212 "Identifier": identifier, 213 "IsListArg": isListArg, 214 "IsListAttr": isListAttr, 215 "StripLeadingColon": stripLeadingColon, 216 }).Parse(` 217{{if .OptionalAttrs -}} 218{{/* Type for specifying all optional attributes. */ -}} 219// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}. 220type {{.Op.Name}}Attr func(optionalAttr) 221 222{{range .OptionalAttrs}} 223// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value. 224{{- if .Description}} 225// 226// value: {{MakeComment .Description}} 227{{- end}} 228// If not specified, defaults to {{StripLeadingColon .DefaultValue}} 229{{- if .HasMinimum}} 230// 231// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}} 232{{- end}} 233func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr { 234 return func(m optionalAttr) { 235 m[{{printf "%q" .Name}}] = value 236 } 237} 238{{end}} 239{{end}} 240 241{{- /* Create a godoc friendly comment. */ -}} 242 243// {{MakeComment .APIDef.Summary}} 244 245{{- with .Op.Deprecation}} 246// 247// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}} 248{{- end -}} 249 250{{- with .APIDef.Description}} 251// 252// {{MakeComment .}} 253{{- end -}} 254 255{{- if .DescribeArguments}} 256// 257// Arguments: 258{{- range .InArgsReordered}} 259// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 260{{- end -}} 261{{- range .RequiredAttrs}} 262// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 263{{- end -}} 264{{- end -}} 265 266{{- if (not .Op.OutputArg) }} 267// 268// Returns the created operation. 269{{- else }} 270{{- if .DescribeOutputs}} 271// 272{{- if eq (len .OutArgs) 1 }} 273// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}} 274{{- else }} 275// Returns: 276{{- range .OutArgs}} 277// {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}} 278{{- end -}} 279{{- end -}} 280{{- end -}} 281{{- end -}} 282{{- /* 283 284 The function signature. 285 Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang 286*/}} 287func {{.Op.Name}} 288 289{{- /* 290 Fill in input arguments: 291 (1) The Scope 292 (2) All input arguments (which may be either []tf.Output or tf.Output) 293 (3) All required attributes 294 (4) Variadic list of optional attributes 295*/ -}} 296 297(scope *Scope 298{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}} 299{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}} 300{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} 301) 302 303{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}} 304 305{{if .OutArgs -}} 306({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}) 307{{- else -}} 308(o *tf.Operation) 309{{- end }} { 310 if scope.Err() != nil { 311 return 312 } 313 {{if .HasAttrs -}} 314 attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}} 315 {{if .OptionalAttrs -}} 316 for _, a := range optional { 317 a(attrs) 318 } 319 {{end -}} 320 {{end -}} 321 opspec := tf.OpSpec{ 322 Type: {{printf "%q" .Op.Name}}, 323 {{if .InArgs -}} 324 Input: []tf.Input{ 325 {{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}} 326 }, 327 {{- end}} 328 {{- if .HasAttrs}} 329 Attrs: attrs, 330 {{- end}} 331 } 332 {{- if .OutArgs}} 333 {{- if .HasListOutput}} 334 op := scope.AddOperation(opspec) 335 if scope.Err() != nil { 336 return 337 } 338 var idx int 339 var err error 340 {{- range $i, $a := .OutArgs}} 341 {{- if $a.IsListArg}} 342 if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { 343 scope.UpdateErr({{printf "%q" $.Op.Name}}, err) 344 return 345 } 346 {{- else }} 347 {{Identifier .RenameTo}} = op.Output(idx) 348 {{- end }}{{- /* if IsListArg */}} 349 {{- end }}{{- /* range .OutArgs */}} 350 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}} 351 {{- else }} 352 op := scope.AddOperation(opspec) 353 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} 354 {{- end }}{{- /* if .HasListOutput */}} 355 {{- else }} 356 return scope.AddOperation(opspec) 357 {{- end }}{{- /* if .OutArgs */}} 358} 359`)) 360) 361 362type attrWrapper struct { 363 op *odpb.OpDef_AttrDef 364 api *adpb.ApiDef_Attr 365} 366 367func (a *attrWrapper) Name() string { return a.api.Name } 368func (a *attrWrapper) RenameTo() string { return a.api.RenameTo } 369func (a *attrWrapper) Description() string { return a.api.Description } 370func (a *attrWrapper) Type() string { return a.op.Type } 371func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) } 372func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum } 373func (a *attrWrapper) Minimum() int64 { return a.op.Minimum } 374func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue } 375 376type argWrapper struct { 377 op *odpb.OpDef_ArgDef 378 api *adpb.ApiDef_Arg 379} 380 381func (a *argWrapper) Name() string { return a.api.Name } 382func (a *argWrapper) RenameTo() string { return a.api.RenameTo } 383func (a *argWrapper) Description() string { return a.api.Description } 384func (a *argWrapper) IsListArg() bool { return isListArg(a.op) } 385 386type tmplArgs struct { 387 Op *odpb.OpDef 388 APIDef *adpb.ApiDef 389 // Op.Attr is split into two categories 390 // (1) Required: These must be specified by the client and are thus 391 // included in the function signature. 392 // (2) Optional: These need not be specified (as they have default 393 // values) and thus do not appear in the function signature. 394 RequiredAttrs []*attrWrapper 395 OptionalAttrs []*attrWrapper 396 InArgs []*argWrapper 397 // Input arguments ordered based on arg_order field of ApiDef. 398 InArgsReordered []*argWrapper 399 OutArgs []*argWrapper 400} 401 402func newTmplArgs(op *odpb.OpDef, apidef *adpb.ApiDef) (*tmplArgs, error) { 403 ret := tmplArgs{Op: op, APIDef: apidef} 404 405 // Setup InArgs field 406 for i, in := range op.InputArg { 407 argCombined := argWrapper{op: in, api: apidef.InArg[i]} 408 ret.InArgs = append(ret.InArgs, &argCombined) 409 } 410 411 // Setup OutArgs field 412 for i, out := range op.OutputArg { 413 argCombined := argWrapper{op: out, api: apidef.OutArg[i]} 414 ret.OutArgs = append(ret.OutArgs, &argCombined) 415 } 416 417 // Setup InArgsReordered field 418 for _, argName := range apidef.ArgOrder { 419 // Find the argument in op.InputArg 420 argIndex := -1 421 for i, in := range op.InputArg { 422 if in.Name == argName { 423 argIndex = i 424 break 425 } 426 } 427 if argIndex == -1 { 428 return nil, fmt.Errorf( 429 "couldn't find argument %s in ApiDef for op %s", 430 argName, op.Name) 431 } 432 argCombined := argWrapper{ 433 op: op.InputArg[argIndex], api: apidef.InArg[argIndex]} 434 ret.InArgsReordered = append(ret.InArgsReordered, &argCombined) 435 } 436 437 if len(op.Attr) == 0 { 438 return &ret, nil 439 } 440 // Attributes related to the InputArg's type are inferred automatically 441 // and are not exposed to the client. 442 inferred := make(map[string]bool) 443 for _, in := range op.InputArg { 444 switch { 445 case in.TypeAttr != "": 446 inferred[in.TypeAttr] = true 447 case in.TypeListAttr != "": 448 inferred[in.TypeListAttr] = true 449 } 450 if in.NumberAttr != "" { 451 inferred[in.NumberAttr] = true 452 } 453 } 454 for i, attr := range op.Attr { 455 if inferred[attr.Name] { 456 continue 457 } 458 attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]} 459 if attr.DefaultValue == nil { 460 ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined) 461 } else { 462 ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined) 463 } 464 } 465 return &ret, nil 466} 467 468func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 } 469func (a *tmplArgs) DescribeArguments() bool { 470 for _, arg := range a.InArgs { 471 if arg.Description() != "" { 472 return true 473 } 474 } 475 for _, attr := range a.RequiredAttrs { 476 if attr.Description() != "" { 477 return true 478 } 479 } 480 return false 481 482} 483func (a *tmplArgs) DescribeOutputs() bool { 484 for _, arg := range a.OutArgs { 485 if arg.Description() != "" { 486 return true 487 } 488 } 489 return false 490} 491func (a *tmplArgs) HasListOutput() bool { 492 for _, arg := range a.OutArgs { 493 if arg.IsListArg() { 494 return true 495 } 496 } 497 return false 498} 499 500func makeComment(lines string) string { 501 return strings.Join(strings.SplitAfter(lines, "\n"), "// ") 502} 503 504// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.) 505// to the corresponding type in Go. 506func goType(tfType string) (string, error) { 507 list, tfType := parseTFType(tfType) 508 var gotype string 509 switch tfType { 510 case "int": 511 gotype = "int64" 512 case "float": 513 gotype = "float32" 514 case "bool": 515 gotype = "bool" 516 case "type": 517 gotype = "tf.DataType" 518 case "shape": 519 gotype = "tf.Shape" 520 case "tensor": 521 gotype = "tf.Tensor" 522 case "string": 523 gotype = "string" 524 default: 525 return "", fmt.Errorf("%q is not a recognized DataType", tfType) 526 } 527 if list { 528 gotype = "[]" + gotype 529 } 530 return gotype, nil 531} 532 533func camelCase(snakeCase string) string { 534 words := strings.Split(snakeCase, "_") 535 for i, w := range words { 536 words[i] = strings.ToUpper(string(w[0])) + w[1:] 537 } 538 return strings.Join(words, "") 539} 540 541// identifier creates an identifier for s usable in the generated Go source 542// code. 543// 544// Avoids collisions with keywords and other identifiers used in the generated 545// code. 546func identifier(s string) string { 547 // Identifiers used in the generated code. 548 if s == "tf" || s == "scope" || s == "err" || s == "op" { 549 return s + "_" 550 } 551 for _, k := range keywords { 552 if s == k { 553 // Alternatively, make the first letter upper case. 554 return s + "_" 555 } 556 } 557 return s 558} 559 560func isListArg(argdef *odpb.OpDef_ArgDef) bool { 561 return argdef.TypeListAttr != "" || argdef.NumberAttr != "" 562} 563 564func isListAttr(attrdef *odpb.OpDef_AttrDef) bool { 565 list, _ := parseTFType(attrdef.Type) 566 return list 567} 568 569// stripLeadingColon removes the prefix of the string up to the first colon. 570// 571// This is useful when 's' corresponds to a "oneof" protocol buffer message. 572// For example, consider the protocol buffer message: 573// oneof value { bool b = 1; int64 i = 2; } 574// proto.CompactTextString) will print "b:true", or "i:7" etc. This function 575// strips out the leading "b:" or "i:". 576func stripLeadingColon(m proto.Message) string { 577 o := prototext.MarshalOptions{Multiline: false} 578 x := o.Format(m) 579 y := strings.SplitN(x, ":", 2) 580 if len(y) < 2 { 581 return x 582 } 583 return y[1] 584} 585 586func parseTFType(tfType string) (list bool, typ string) { 587 const ( 588 listPrefix = "list(" 589 listSuffix = ")" 590 ) 591 if strings.HasPrefix(tfType, listPrefix) && strings.HasSuffix(tfType, listSuffix) { 592 return true, strings.TrimSuffix(strings.TrimPrefix(tfType, listPrefix), listSuffix) 593 } 594 return false, tfType 595} 596