1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17// Package internal generates Go source code with functions for TensorFlow operations. 18// 19// The basic outline of the generated API is as follows: 20// 21// - One function for each TensorFlow operation 22// - The arguments to the function are the inputs and required attributes of the operation 23// - The function returns the outputs 24// - A function is also generated for each optional attribute of the operation. 25// 26// There is a possibility that there are name collisions between the functions 27// generated for ops and the functions generated for optional attributes. For 28// now, we ignore those, but will need to revisit if a collision is actually 29// encountered. 30package internal 31 32/* 33#include <stdlib.h> 34 35#include "tensorflow/c/c_api.h" 36*/ 37import "C" 38 39import ( 40 "fmt" 41 "io" 42 "io/ioutil" 43 "path" 44 "reflect" 45 "strings" 46 "text/template" 47 "unsafe" 48 49 "github.com/golang/protobuf/proto" 50 pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework" 51) 52 53// GenerateFunctionsForRegisteredOps writes a Go source code file to w 54// containing functions for each TensorFlow operation registered in the address 55// space of the calling process. 56// apidefDirs should be a contain of directories containing api_def_*.pbtxt 57// files to load. 58func GenerateFunctionsForRegisteredOps( 59 w io.Writer, apidefDirs []string) error { 60 ops, apimap, err := registeredOps() 61 if err != nil { 62 return err 63 } 64 for _, dir := range apidefDirs { 65 if err = updateAPIDefs(apimap, dir); err != nil { 66 return err 67 } 68 } 69 return generateFunctionsForOps(w, ops, apimap) 70} 71 72func registeredOps() (*pb.OpList, *apiDefMap, error) { 73 buf := C.TF_GetAllOpList() 74 defer C.TF_DeleteBuffer(buf) 75 var ( 76 list = new(pb.OpList) 77 size = int(buf.length) 78 // A []byte backed by C memory. 79 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 80 data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size] 81 err = proto.Unmarshal(data, list) 82 ) 83 if err != nil { 84 return nil, nil, err 85 } 86 apimap, err := newAPIDefMap(list) 87 return list, apimap, err 88} 89 90func updateAPIDefs(m *apiDefMap, dir string) error { 91 files, err := ioutil.ReadDir(dir) 92 if err != nil { 93 return err 94 } 95 for _, file := range files { 96 data, err := ioutil.ReadFile(path.Join(dir, file.Name())) 97 if err != nil { 98 return fmt.Errorf("failed to read %q: %v", file.Name(), err) 99 } 100 if err = m.Put(string(data)); err != nil { 101 return fmt.Errorf("failed to process %q: %v", file.Name(), err) 102 } 103 } 104 return nil 105} 106 107func generateFunctionsForOps(w io.Writer, ops *pb.OpList, apimap *apiDefMap) error { 108 thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath() 109 if err := tmplHeader.Execute(w, thisPackage); err != nil { 110 return err 111 } 112 blacklist := map[string]bool{ 113 "Const": true, 114 "PyFunc": true, 115 "PyFuncStateless": true, 116 } 117 for _, op := range ops.Op { 118 if blacklist[op.Name] { 119 continue 120 } 121 apidef, err := apimap.Get(op.Name) 122 if err != nil { 123 return err 124 } 125 if err := generateFunctionForOp(w, op, apidef); err != nil { 126 return err 127 } 128 } 129 return nil 130} 131 132func generateFunctionForOp(w io.Writer, op *pb.OpDef, apidef *pb.ApiDef) error { 133 if strings.HasPrefix(op.Name, "_") { // Internal operation 134 return nil 135 } 136 // Ignore operations where the Go types corresponding to the TensorFlow 137 // type haven't been worked out (such as "func"s). 138 for _, a := range op.Attr { 139 if _, err := goType(a.Type); err != nil { 140 return nil 141 } 142 } 143 // Also, haven't figured out reference types yet, so ignore those too. 144 for _, a := range op.InputArg { 145 if a.IsRef { 146 return nil 147 } 148 } 149 for _, a := range op.OutputArg { 150 if a.IsRef { 151 return nil 152 } 153 } 154 if apidef.Summary == "" { 155 // Undocumented operation, perhaps a sign of not being ready to 156 // export. 157 return nil 158 } 159 tmplArgs, err := newTmplArgs(op, apidef) 160 if err != nil { 161 return err 162 } 163 return tmplOp.Execute(w, tmplArgs) 164} 165 166var ( 167 // Go keywords that cannot be used as identifiers. 168 // From https://golang.org/ref/spec#Keywords 169 keywords = []string{ 170 "break", "default", "func", "interface", "select", "case", 171 "defer", "go", "map", "struct", "chan", "else", "goto", 172 "package", "switch", "const", "fallthrough", "if", "range", 173 "type", "continue", "for", "import", "return", "var", 174 } 175 176 tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT 177// This file was machine generated by {{.}} 178// 179// WARNING: This generation of wrapper function for TensorFlow ops is in an 180// experimental state. The generated API can change without notice. 181 182package op 183 184import tf "github.com/tensorflow/tensorflow/tensorflow/go" 185 186// optionalAttr is an intentionally un-exported type to hide 187// details of how optional attributes to operations are implemented. 188type optionalAttr map[string]interface{} 189 190func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) { 191 size, err := op.OutputListSize(output) 192 if err != nil { 193 return nil, start, err 194 } 195 list := make([]tf.Output, size) 196 for i := 0; i < size; i++ { 197 list[i] = op.Output(start + i) 198 } 199 return list, start + size, nil 200} 201`)) 202 203 tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{ 204 "MakeComment": makeComment, 205 "GoType": goType, 206 "CamelCase": camelCase, 207 "Identifier": identifier, 208 "IsListArg": isListArg, 209 "IsListAttr": isListAttr, 210 "StripLeadingColon": stripLeadingColon, 211 }).Parse(` 212{{if .OptionalAttrs -}} 213{{/* Type for specifying all optional attributes. */ -}} 214// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}. 215type {{.Op.Name}}Attr func(optionalAttr) 216 217{{range .OptionalAttrs}} 218// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value. 219{{- if .Description}} 220// 221// value: {{MakeComment .Description}} 222{{- end}} 223// If not specified, defaults to {{StripLeadingColon .DefaultValue}} 224{{- if .HasMinimum}} 225// 226// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}} 227{{- end}} 228func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr { 229 return func(m optionalAttr) { 230 m[{{printf "%q" .Name}}] = value 231 } 232} 233{{end}} 234{{end}} 235 236{{- /* Create a godoc friendly comment. */ -}} 237 238// {{MakeComment .APIDef.Summary}} 239 240{{- with .Op.Deprecation}} 241// 242// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}} 243{{- end -}} 244 245{{- with .APIDef.Description}} 246// 247// {{MakeComment .}} 248{{- end -}} 249 250{{- if .DescribeArguments}} 251// 252// Arguments: 253{{- range .InArgsReordered}} 254// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 255{{- end -}} 256{{- range .RequiredAttrs}} 257// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} 258{{- end -}} 259{{- end -}} 260 261{{- if (not .Op.OutputArg) }} 262// 263// Returns the created operation. 264{{- else }} 265{{- if .DescribeOutputs}} 266// 267{{- if ((len .OutArgs) eq 1) }} 268// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}} 269{{- else }} 270// Returns: 271{{- range .OutArgs}} 272// {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}} 273{{- end -}} 274{{- end -}} 275{{- end -}} 276{{- end -}} 277{{- /* 278 279 The function signature. 280 Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang 281*/}} 282func {{.Op.Name}} 283 284{{- /* 285 Fill in input arguments: 286 (1) The Scope 287 (2) All input arguments (which may be either []tf.Output or tf.Output) 288 (3) All required attributes 289 (4) Variadic list of optional attributes 290*/ -}} 291 292(scope *Scope 293{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}} 294{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}} 295{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} 296) 297 298{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}} 299 300{{if .OutArgs -}} 301({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}) 302{{- else -}} 303(o *tf.Operation) 304{{- end }} { 305 if scope.Err() != nil { 306 return 307 } 308 {{if .HasAttrs -}} 309 attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}} 310 {{if .OptionalAttrs -}} 311 for _, a := range optional { 312 a(attrs) 313 } 314 {{end -}} 315 {{end -}} 316 opspec := tf.OpSpec{ 317 Type: {{printf "%q" .Op.Name}}, 318 {{if .InArgs -}} 319 Input: []tf.Input{ 320 {{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}} 321 }, 322 {{- end}} 323 {{- if .HasAttrs}} 324 Attrs: attrs, 325 {{- end}} 326 } 327 {{- if .OutArgs}} 328 {{- if .HasListOutput}} 329 op := scope.AddOperation(opspec) 330 if scope.Err() != nil { 331 return 332 } 333 var idx int 334 var err error 335 {{- range $i, $a := .OutArgs}} 336 {{- if $a.IsListArg}} 337 if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { 338 scope.UpdateErr({{printf "%q" $.Op.Name}}, err) 339 return 340 } 341 {{- else }} 342 {{Identifier .RenameTo}} = op.Output(idx) 343 {{- end }}{{- /* if IsListArg */}} 344 {{- end }}{{- /* range .OutArgs */}} 345 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}} 346 {{- else }} 347 op := scope.AddOperation(opspec) 348 return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} 349 {{- end }}{{- /* if .HasListOutput */}} 350 {{- else }} 351 return scope.AddOperation(opspec) 352 {{- end }}{{- /* if .OutArgs */}} 353} 354`)) 355) 356 357type attrWrapper struct { 358 op *pb.OpDef_AttrDef 359 api *pb.ApiDef_Attr 360} 361 362func (a *attrWrapper) Name() string { return a.api.Name } 363func (a *attrWrapper) RenameTo() string { return a.api.RenameTo } 364func (a *attrWrapper) Description() string { return a.api.Description } 365func (a *attrWrapper) Type() string { return a.op.Type } 366func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) } 367func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum } 368func (a *attrWrapper) Minimum() int64 { return a.op.Minimum } 369func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue } 370 371type argWrapper struct { 372 op *pb.OpDef_ArgDef 373 api *pb.ApiDef_Arg 374} 375 376func (a *argWrapper) Name() string { return a.api.Name } 377func (a *argWrapper) RenameTo() string { return a.api.RenameTo } 378func (a *argWrapper) Description() string { return a.api.Description } 379func (a *argWrapper) IsListArg() bool { return isListArg(a.op) } 380 381type tmplArgs struct { 382 Op *pb.OpDef 383 APIDef *pb.ApiDef 384 // Op.Attr is split into two categories 385 // (1) Required: These must be specified by the client and are thus 386 // included in the function signature. 387 // (2) Optional: These need not be specified (as they have default 388 // values) and thus do not appear in the function signature. 389 RequiredAttrs []*attrWrapper 390 OptionalAttrs []*attrWrapper 391 InArgs []*argWrapper 392 // Input arguments ordered based on arg_order field of ApiDef. 393 InArgsReordered []*argWrapper 394 OutArgs []*argWrapper 395} 396 397func newTmplArgs(op *pb.OpDef, apidef *pb.ApiDef) (*tmplArgs, error) { 398 ret := tmplArgs{Op: op, APIDef: apidef} 399 400 // Setup InArgs field 401 for i, in := range op.InputArg { 402 argCombined := argWrapper{op: in, api: apidef.InArg[i]} 403 ret.InArgs = append(ret.InArgs, &argCombined) 404 } 405 406 // Setup OutArgs field 407 for i, out := range op.OutputArg { 408 argCombined := argWrapper{op: out, api: apidef.OutArg[i]} 409 ret.OutArgs = append(ret.OutArgs, &argCombined) 410 } 411 412 // Setup InArgsReordered field 413 for _, argName := range apidef.ArgOrder { 414 // Find the argument in op.InputArg 415 argIndex := -1 416 for i, in := range op.InputArg { 417 if in.Name == argName { 418 argIndex = i 419 break 420 } 421 } 422 if argIndex == -1 { 423 return nil, fmt.Errorf( 424 "couldn't find argument %s in ApiDef for op %s", 425 argName, op.Name) 426 } 427 argCombined := argWrapper{ 428 op: op.InputArg[argIndex], api: apidef.InArg[argIndex]} 429 ret.InArgsReordered = append(ret.InArgsReordered, &argCombined) 430 } 431 432 if len(op.Attr) == 0 { 433 return &ret, nil 434 } 435 // Attributes related to the InputArg's type are inferred automatically 436 // and are not exposed to the client. 437 inferred := make(map[string]bool) 438 for _, in := range op.InputArg { 439 switch { 440 case in.TypeAttr != "": 441 inferred[in.TypeAttr] = true 442 case in.TypeListAttr != "": 443 inferred[in.TypeListAttr] = true 444 } 445 if in.NumberAttr != "" { 446 inferred[in.NumberAttr] = true 447 } 448 } 449 for i, attr := range op.Attr { 450 if inferred[attr.Name] { 451 continue 452 } 453 attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]} 454 if attr.DefaultValue == nil { 455 ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined) 456 } else { 457 ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined) 458 } 459 } 460 return &ret, nil 461} 462 463func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 } 464func (a *tmplArgs) DescribeArguments() bool { 465 for _, arg := range a.InArgs { 466 if arg.Description() != "" { 467 return true 468 } 469 } 470 for _, attr := range a.RequiredAttrs { 471 if attr.Description() != "" { 472 return true 473 } 474 } 475 return false 476 477} 478func (a *tmplArgs) DescribeOutputs() bool { 479 for _, arg := range a.OutArgs { 480 if arg.Description() != "" { 481 return true 482 } 483 } 484 return false 485} 486func (a *tmplArgs) HasListOutput() bool { 487 for _, arg := range a.OutArgs { 488 if arg.IsListArg() { 489 return true 490 } 491 } 492 return false 493} 494 495func makeComment(lines string) string { 496 return strings.Join(strings.SplitAfter(lines, "\n"), "// ") 497} 498 499// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.) 500// to the corresponding type in Go. 501func goType(tfType string) (string, error) { 502 list, tfType := parseTFType(tfType) 503 var gotype string 504 switch tfType { 505 case "int": 506 gotype = "int64" 507 case "float": 508 gotype = "float32" 509 case "bool": 510 gotype = "bool" 511 case "type": 512 gotype = "tf.DataType" 513 case "shape": 514 gotype = "tf.Shape" 515 case "tensor": 516 gotype = "tf.Tensor" 517 case "string": 518 gotype = "string" 519 default: 520 return "", fmt.Errorf("%q is not a recognized DataType", tfType) 521 } 522 if list { 523 gotype = "[]" + gotype 524 } 525 return gotype, nil 526} 527 528func camelCase(snakeCase string) string { 529 words := strings.Split(snakeCase, "_") 530 for i, w := range words { 531 words[i] = strings.ToUpper(string(w[0])) + w[1:] 532 } 533 return strings.Join(words, "") 534} 535 536// identifier creates an identifier for s usable in the generated Go source 537// code. 538// 539// Avoids collisions with keywords and other identifiers used in the generated 540// code. 541func identifier(s string) string { 542 // Identifiers used in the generated code. 543 if s == "tf" || s == "scope" || s == "err" || s == "op" { 544 return s + "_" 545 } 546 for _, k := range keywords { 547 if s == k { 548 // Alternatively, make the first letter upper case. 549 return s + "_" 550 } 551 } 552 return s 553} 554 555func isListArg(argdef *pb.OpDef_ArgDef) bool { 556 return argdef.TypeListAttr != "" || argdef.NumberAttr != "" 557} 558 559func isListAttr(attrdef *pb.OpDef_AttrDef) bool { 560 list, _ := parseTFType(attrdef.Type) 561 return list 562} 563 564// stripLeadingColon removes the prefix of the string up to the first colon. 565// 566// This is useful when 's' corresponds to a "oneof" protocol buffer message. 567// For example, consider the protocol buffer message: 568// oneof value { bool b = 1; int64 i = 2; } 569// String() on a Go corresponding object (using proto.CompactTextString) will 570// print "b:true", or "i:7" etc. This function strips out the leading "b:" or 571// "i:". 572func stripLeadingColon(s fmt.Stringer) string { 573 x := s.String() 574 y := strings.SplitN(x, ":", 2) 575 if len(y) < 2 { 576 return x 577 } 578 return y[1] 579} 580 581func parseTFType(tfType string) (list bool, typ string) { 582 const ( 583 listPrefix = "list(" 584 listSuffix = ")" 585 ) 586 if strings.HasPrefix(tfType, listPrefix) && strings.HasSuffix(tfType, listSuffix) { 587 return true, strings.TrimSuffix(strings.TrimPrefix(tfType, listPrefix), listSuffix) 588 } 589 return false, tfType 590} 591