1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package tensorflow 18 19// #include "tensorflow/c/c_api.h" 20// 21// #include <stdlib.h> 22// #include <string.h> 23// 24// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc, 25// const char* attr_name, 26// const int64_t* flat_dims, 27// const int* num_dims, 28// int num_shapes) { 29// const int64_t** dims = 30// (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes); 31// int i = 0; 32// for (i = 0; i < num_shapes; i++) { 33// dims[i] = flat_dims; 34// if (num_dims[i] > 0) { 35// // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0. 36// flat_dims += num_dims[i]; 37// } 38// } 39// TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes); 40// free(dims); 41// } 42import "C" 43 44import ( 45 "fmt" 46 "io" 47 "runtime" 48 "unsafe" 49) 50 51// Graph represents a computation graph. Graphs may be shared between sessions. 52type Graph struct { 53 c *C.TF_Graph 54} 55 56// The GraphImportOptions struct holds parameters for the ImportWithOptions function. 57type GraphImportOptions struct { 58 // Node prefix 59 Prefix string 60 61 // Execution device 62 Device string 63 64 // inputMapping defines a mapping between Outputs in the graph 65 // and Outputs they should be replaced with. 66 inputMapping map[struct { 67 Name string 68 Index int 69 }]Output 70 71 // TODO: extend this structure to support more options from TF_ImportGraphDefOptions 72} 73 74// AddInputMapping adds a mapping between an Output in the imported graph 75// and an Output in the destination graph that it should be replaced with, 76// where src:srcIndex is the name of the Operation and Output index to 77// replace and dst is the output to replace it with. 78func (o *GraphImportOptions) AddInputMapping(src string, srcIndex int, dst Output) { 79 if o.inputMapping == nil { 80 o.inputMapping = make(map[struct { 81 Name string 82 Index int 83 }]Output) 84 } 85 o.inputMapping[struct { 86 Name string 87 Index int 88 }{src, srcIndex}] = dst 89} 90 91// NewGraph returns a new Graph. 92func NewGraph() *Graph { 93 g := &Graph{C.TF_NewGraph()} 94 runtime.SetFinalizer(g, (*Graph).finalizer) 95 return g 96} 97 98func (g *Graph) finalizer() { 99 C.TF_DeleteGraph(g.c) 100} 101 102// WriteTo writes out a serialized representation of g to w. 103// 104// Implements the io.WriterTo interface. 105func (g *Graph) WriteTo(w io.Writer) (int64, error) { 106 buf := C.TF_NewBuffer() 107 defer C.TF_DeleteBuffer(buf) 108 status := newStatus() 109 C.TF_GraphToGraphDef(g.c, buf, status.c) 110 if err := status.Err(); err != nil { 111 return 0, err 112 } 113 if buf.length > (1 << 30) { 114 // For very large graphs, the writes can be chunked. 115 // Punt on that for now. 116 return 0, fmt.Errorf("Graph is too large to write out, Graph.WriteTo needs to be updated") 117 } 118 // A []byte slice backed by C memory. 119 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 120 length := int(buf.length) 121 var slice []byte 122 if unsafe.Sizeof(unsafe.Pointer(nil)) == 8 { 123 slice = (*[1<<50 - 1]byte)(unsafe.Pointer(buf.data))[:length:length] 124 } else { 125 slice = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:length:length] 126 } 127 n, err := w.Write(slice) 128 return int64(n), err 129} 130 131// ImportWithOptions imports the nodes and edges from a serialized representation of 132// another Graph into g. 133// 134// Multiple options can be specified for the newly imported nodes. 135func (g *Graph) ImportWithOptions(def []byte, options GraphImportOptions) error { 136 cprefix := C.CString(options.Prefix) 137 defer C.free(unsafe.Pointer(cprefix)) 138 139 opts := C.TF_NewImportGraphDefOptions() 140 defer C.TF_DeleteImportGraphDefOptions(opts) 141 C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix) 142 143 if len(options.Device) != 0 { 144 cdev := C.CString(options.Device) 145 defer C.free(unsafe.Pointer(cdev)) 146 C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev) 147 } 148 149 for src, dst := range options.inputMapping { 150 cSrcName := C.CString(src.Name) 151 C.TF_ImportGraphDefOptionsAddInputMapping(opts, cSrcName, C.int(src.Index), dst.c()) 152 C.free(unsafe.Pointer(cSrcName)) 153 } 154 155 buf := C.TF_NewBuffer() 156 defer C.TF_DeleteBuffer(buf) 157 buf.length = C.size_t(len(def)) 158 buf.data = C.CBytes(def) 159 if buf.data == nil { 160 return fmt.Errorf("unable to allocate memory") 161 } 162 defer C.free(buf.data) 163 164 status := newStatus() 165 166 C.TF_GraphImportGraphDef(g.c, buf, opts, status.c) 167 if err := status.Err(); err != nil { 168 return err 169 } 170 171 return nil 172} 173 174// Import imports the nodes and edges from a serialized representation of 175// another Graph into g. 176// 177// Names of imported nodes will be prefixed with prefix. 178func (g *Graph) Import(def []byte, prefix string) error { 179 return g.ImportWithOptions(def, GraphImportOptions{Prefix: prefix}) 180} 181 182// Operation returns the Operation named name in the Graph, or nil if no such 183// operation is present. 184func (g *Graph) Operation(name string) *Operation { 185 cname := C.CString(name) 186 defer C.free(unsafe.Pointer(cname)) 187 cop := C.TF_GraphOperationByName(g.c, cname) 188 if cop == nil { 189 return nil 190 } 191 return &Operation{cop, g} 192} 193 194// Operations returns a list of all operations in the graph 195func (g *Graph) Operations() []Operation { 196 var pos C.size_t 197 ops := []Operation{} 198 for { 199 cop := C.TF_GraphNextOperation(g.c, &pos) 200 if cop == nil { 201 break 202 } 203 ops = append(ops, Operation{cop, g}) 204 } 205 return ops 206} 207 208// AddGradients adds operations to compute the partial derivatives of the sum of tensors in y 209// with respect to tensors in x, i.e., d(y[0] + y[1] + ...) / d x[0], d(y[0] + y[1] + ... ) / d x[1] etc. 210// 211// prefix, if non-empty, is the name prefix used for all operations added to the graph to compute 212// these gradients. 213func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) { 214 var ( 215 cprefix *C.char 216 217 cy = make([]C.TF_Output, len(y)) 218 cx = make([]C.TF_Output, len(x)) 219 cdx = make([]C.TF_Output, len(dx)) 220 cdy = make([]C.TF_Output, len(x)) 221 222 pcy *C.TF_Output 223 pcx *C.TF_Output 224 pcdx *C.TF_Output 225 pcdy *C.TF_Output 226 227 status = newStatus() 228 ) 229 230 if len(y) > 0 { 231 pcy = &cy[0] 232 for i, o := range y { 233 cy[i] = o.c() 234 } 235 } 236 if len(x) > 0 { 237 pcx = &cx[0] 238 for i, o := range x { 239 cx[i] = o.c() 240 } 241 pcdy = &cdy[0] 242 } 243 if len(dx) > 0 { 244 pcdx = &cdx[0] 245 for i, o := range dx { 246 cdx[i] = o.c() 247 } 248 } 249 250 // If prefix is "", the C.TF_AddGradientsWithPrefix need cprefix to be nil but not "" 251 if len(prefix) != 0 { 252 cprefix = C.CString(prefix) 253 defer C.free(unsafe.Pointer(cprefix)) 254 } 255 256 C.TF_AddGradientsWithPrefix(g.c, cprefix, pcy, C.int(len(y)), pcx, C.int(len(x)), pcdx, status.c, pcdy) 257 258 if err := status.Err(); err != nil { 259 return nil, err 260 } 261 dy := make([]Output, len(x)) 262 for i, co := range cdy { 263 op := &Operation{co.oper, g} 264 dy[i] = Output{op, int(co.index)} 265 } 266 267 return dy, nil 268} 269 270// OpSpec is the specification of an Operation to be added to a Graph 271// (using Graph.AddOperation). 272type OpSpec struct { 273 // Type of the operation (e.g., "Add", "MatMul"). 274 Type string 275 276 // Name by which the added operation will be referred to in the Graph. 277 // If omitted, defaults to Type. 278 Name string 279 280 // Inputs to this operation, which in turn must be outputs 281 // of other operations already added to the Graph. 282 // 283 // An operation may have multiple inputs with individual inputs being 284 // either a single tensor produced by another operation or a list of 285 // tensors produced by multiple operations. For example, the "Concat" 286 // operation takes two inputs: (1) the dimension along which to 287 // concatenate and (2) a list of tensors to concatenate. Thus, for 288 // Concat, len(Input) must be 2, with the first element being an Output 289 // and the second being an OutputList. 290 Input []Input 291 292 // Map from attribute name to its value that will be attached to this 293 // operation. 294 Attrs map[string]interface{} 295 296 // Operations that must be executed before executing the operation 297 // being added. 298 ControlDependencies []*Operation 299 300 // The device on which the operation should be executed. 301 // If omitted, an appropriate device will automatically be selected. 302 // 303 // For example, if set of "/device:GPU:0", then the operation will 304 // execute on GPU #0. 305 Device string 306 307 // Other possible fields: ColocateWith. 308} 309 310// AddOperation adds an operation to g. 311func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { 312 if args.Name == "" { 313 args.Name = args.Type 314 } 315 cname := C.CString(args.Name) 316 ctype := C.CString(args.Type) 317 cdesc := C.TF_NewOperation(g.c, ctype, cname) 318 C.free(unsafe.Pointer(cname)) 319 C.free(unsafe.Pointer(ctype)) 320 321 for _, in := range args.Input { 322 switch in := in.(type) { 323 case Output: 324 C.TF_AddInput(cdesc, in.c()) 325 case OutputList: 326 size := len(in) 327 list := make([]C.TF_Output, size) 328 for i, v := range in { 329 list[i] = v.c() 330 } 331 if size > 0 { 332 C.TF_AddInputList(cdesc, &list[0], C.int(size)) 333 } else { 334 C.TF_AddInputList(cdesc, nil, 0) 335 } 336 } 337 } 338 for _, in := range args.ControlDependencies { 339 C.TF_AddControlInput(cdesc, in.c) 340 } 341 status := newStatus() 342 for name, value := range args.Attrs { 343 if err := setAttr(cdesc, status, name, value); err != nil { 344 // Memory leak here as the TF_OperationDescription 345 // object will not be cleaned up. At the time of this 346 // writing, this was next to impossible since it 347 // required value to be a string tensor with 348 // incorrectly encoded strings. Given this rarity, live 349 // with the memory leak. If it becomes a real problem, 350 // consider adding a TF_DeleteOperationDescription 351 // function to the C API. 352 return nil, fmt.Errorf("%v (memory will be leaked)", err) 353 } 354 } 355 if len(args.Device) > 0 { 356 cdevice := C.CString(args.Device) 357 C.TF_SetDevice(cdesc, cdevice) 358 C.free(unsafe.Pointer(cdevice)) 359 } 360 c := C.TF_FinishOperation(cdesc, status.c) 361 if err := status.Err(); err != nil { 362 return nil, err 363 } 364 return &Operation{c, g}, nil 365} 366 367func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { 368 cAttrName := C.CString(name) 369 defer C.free(unsafe.Pointer(cAttrName)) 370 switch value := value.(type) { 371 case string: 372 cstr := C.CString(value) 373 C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value))) 374 C.free(unsafe.Pointer(cstr)) 375 case []string: 376 size := len(value) 377 list := make([]unsafe.Pointer, size) 378 lens := make([]C.size_t, size) 379 for i, s := range value { 380 list[i] = unsafe.Pointer(C.CString(s)) 381 lens[i] = C.size_t(len(s)) 382 } 383 if size > 0 { 384 C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) 385 } else { 386 C.TF_SetAttrStringList(cdesc, cAttrName, nil, nil, 0) 387 } 388 for _, s := range list { 389 C.free(s) 390 } 391 case int64: 392 C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value)) 393 case []int64: 394 size := len(value) 395 list := make([]C.int64_t, size) 396 for i, v := range value { 397 list[i] = C.int64_t(v) 398 } 399 if size > 0 { 400 C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) 401 } else { 402 C.TF_SetAttrIntList(cdesc, cAttrName, nil, 0) 403 } 404 case float32: 405 C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) 406 case []float32: 407 size := len(value) 408 list := make([]C.float, size) 409 for i, v := range value { 410 list[i] = C.float(v) 411 } 412 if size > 0 { 413 C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) 414 } else { 415 C.TF_SetAttrFloatList(cdesc, cAttrName, nil, 0) 416 } 417 case bool: 418 v := C.uchar(0) 419 if value { 420 v = 1 421 } 422 C.TF_SetAttrBool(cdesc, cAttrName, v) 423 case []bool: 424 size := len(value) 425 list := make([]C.uchar, size) 426 for i, v := range value { 427 if v { 428 list[i] = 1 429 } 430 } 431 if size > 0 { 432 C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) 433 } else { 434 C.TF_SetAttrBoolList(cdesc, cAttrName, nil, 0) 435 } 436 case DataType: 437 C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) 438 case []DataType: 439 var list *C.TF_DataType 440 if len(value) > 0 { 441 list = (*C.TF_DataType)(&value[0]) 442 } 443 C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) 444 case *Tensor: 445 C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c) 446 if err := status.Err(); err != nil { 447 return fmt.Errorf("bad value for attribute %q: %v", name, err) 448 } 449 case []*Tensor: 450 size := len(value) 451 list := make([]*C.TF_Tensor, size) 452 for i, v := range value { 453 list[i] = v.c 454 } 455 var plist **C.TF_Tensor 456 if size > 0 { 457 plist = &list[0] 458 } 459 C.TF_SetAttrTensorList(cdesc, cAttrName, plist, C.int(size), status.c) 460 if err := status.Err(); err != nil { 461 return fmt.Errorf("bad value for attribute %q: %v", name, err) 462 } 463 case Shape: 464 ndims := C.int(value.NumDimensions()) 465 var dimsp *C.int64_t 466 if ndims > 0 { 467 dims := make([]C.int64_t, ndims) 468 for i, d := range value.dims { 469 dims[i] = C.int64_t(d) 470 } 471 dimsp = &dims[0] 472 } 473 C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) 474 case []Shape: 475 if len(value) == 0 { 476 C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0) 477 } else { 478 var flatDims []C.int64_t 479 ndims := make([]C.int, len(value)) 480 for i, s := range value { 481 nd := s.NumDimensions() 482 ndims[i] = C.int(nd) 483 for _, d := range s.dims { 484 flatDims = append(flatDims, C.int64_t(d)) 485 } 486 } 487 var flatDimsp *C.int64_t 488 if len(flatDims) > 0 { 489 flatDimsp = &flatDims[0] 490 } 491 C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value))) 492 } 493 default: 494 return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) 495 } 496 return nil 497} 498 499type LibraryHandler struct { 500 cptr *C.TF_Library 501} 502 503// Load library content into current context, useful to load ops implementation into non-monolithic TF build. Returns LibraryHandler or nil and error 504func LoadLibrary(path string) (*LibraryHandler, error) { 505 status := newStatus() 506 507 cpath := C.CString(path) 508 defer C.free(unsafe.Pointer(cpath)) 509 cptr := C.TF_LoadLibrary(cpath, status.c) 510 if cptr == nil || status.Code() != C.TF_OK { 511 return nil, fmt.Errorf("could not load library %s: code: %d, error: %s", path, status.Code(), status.String()) 512 } 513 514 lh := &LibraryHandler{ 515 cptr: cptr, 516 } 517 518 runtime.SetFinalizer(lh, (*LibraryHandler).free) 519 return lh, nil 520} 521 522func (lh *LibraryHandler) free() { 523 if lh == nil || lh.cptr == nil { 524 return 525 } 526 527 C.TF_DeleteLibraryHandle(lh.cptr) 528} 529