/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tensorflow

// #include "tensorflow/c/c_api.h"
//
// #include <stdlib.h>
// #include <string.h>
//
// // TF_SetAttrShapeList_Helper adapts a flattened list of shapes to the
// // array-of-pointers form taken by TF_SetAttrShapeList.
// //
// // flat_dims holds the dimensions of all shapes concatenated back to back;
// // num_dims[i] is the number of dimensions of shape i. dims[i] is pointed at
// // the start of shape i inside flat_dims.
// //
// // NOTE(review): the malloc result is not checked for NULL.
// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc,
//                                 const char* attr_name,
//                                 const int64_t* flat_dims,
//                                 const int* num_dims,
//                                 int num_shapes) {
//   const int64_t** dims =
//       (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes);
//   int i = 0;
//   for (i = 0; i < num_shapes; i++) {
//     dims[i] = flat_dims;
//     if (num_dims[i] > 0) {
//       // Only shapes with a positive number of dimensions consume entries
//       // of flat_dims, so the cursor is advanced only for them. This also
//       // avoids pointer arithmetic on flat_dims when it is NULL, which
//       // happens iff num_shapes is 0 or all elements in num_dims are <= 0.
//       flat_dims += num_dims[i];
//     }
//   }
//   TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes);
//   free(dims);
// }
import "C"

import (
	"fmt"
	"io"
	"runtime"
	"unsafe"
)

// Graph represents a computation graph. Graphs may be shared between sessions.
//
// Use NewGraph to create one.
type Graph struct {
	// c is the underlying C graph handle.
	c *C.TF_Graph
}

// GraphImportOptions holds the options accepted by Graph.ImportWithOptions.
type GraphImportOptions struct {
	// Prefix is prepended to the names of all imported nodes.
	Prefix string

	// Device is intended to be the default device for imported nodes.
	// NOTE: ImportWithOptions currently returns an error when Device is
	// non-empty; see https://github.com/tensorflow/tensorflow/issues/23257.
	Device string

	// TODO: extend this structure to support more options from TF_ImportGraphDefOptions
}

// NewGraph returns a new Graph.
//
// A finalizer is registered on the returned Graph that deletes the
// underlying TF_Graph once the Graph is garbage collected.
68func NewGraph() *Graph { 69 g := &Graph{C.TF_NewGraph()} 70 runtime.SetFinalizer(g, (*Graph).finalizer) 71 return g 72} 73 74func (g *Graph) finalizer() { 75 C.TF_DeleteGraph(g.c) 76} 77 78// WriteTo writes out a serialized representation of g to w. 79// 80// Implements the io.WriterTo interface. 81func (g *Graph) WriteTo(w io.Writer) (int64, error) { 82 buf := C.TF_NewBuffer() 83 defer C.TF_DeleteBuffer(buf) 84 status := newStatus() 85 C.TF_GraphToGraphDef(g.c, buf, status.c) 86 if err := status.Err(); err != nil { 87 return 0, err 88 } 89 if buf.length > (1 << 30) { 90 // For very large graphs, the writes can be chunked. 91 // Punt on that for now. 92 return 0, fmt.Errorf("Graph is too large to write out, Graph.WriteTo needs to be updated") 93 } 94 // A []byte slice backed by C memory. 95 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 96 length := int(buf.length) 97 slice := (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:length:length] 98 n, err := w.Write(slice) 99 return int64(n), err 100} 101 102// ImportWithOptions imports the nodes and edges from a serialized representation of 103// another Graph into g. 104// 105// Multiple options can be specified for the newly imported nodes. 106func (g *Graph) ImportWithOptions(def []byte, options GraphImportOptions) error { 107 cprefix := C.CString(options.Prefix) 108 defer C.free(unsafe.Pointer(cprefix)) 109 110 opts := C.TF_NewImportGraphDefOptions() 111 defer C.TF_DeleteImportGraphDefOptions(opts) 112 C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix) 113 114 if len(options.Device) != 0 { 115 // TODO(ashankar): Remove this error and uncomment below 116 // when a release of the C library which includes 117 // https://github.com/tensorflow/tensorflow/commit/e0af5ac53e5a8ad9b07cdd5738c0a8e12f938c4e 118 // has been made. 
119 // See https://github.com/tensorflow/tensorflow/issues/23257 120 return fmt.Errorf("GraphImportOptions.Device is only supported with the TensorFlow C library versions after 1.12 (or built from master). See https://github.com/tensorflow/tensorflow/issues/23257") 121 /* 122 cdev := C.CString(options.Device) 123 defer C.free(unsafe.Pointer(cdev)) 124 C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev) 125 */ 126 } 127 128 buf := C.TF_NewBuffer() 129 defer C.TF_DeleteBuffer(buf) 130 // Would have preferred to use C.CBytes, but that does not play well 131 // with "go vet" till https://github.com/golang/go/issues/17201 is 132 // resolved. 133 buf.length = C.size_t(len(def)) 134 buf.data = C.malloc(buf.length) 135 if buf.data == nil { 136 return fmt.Errorf("unable to allocate memory") 137 } 138 defer C.free(buf.data) 139 C.memcpy(buf.data, unsafe.Pointer(&def[0]), buf.length) 140 141 status := newStatus() 142 143 C.TF_GraphImportGraphDef(g.c, buf, opts, status.c) 144 if err := status.Err(); err != nil { 145 return err 146 } 147 148 return nil 149} 150 151// Import imports the nodes and edges from a serialized representation of 152// another Graph into g. 153// 154// Names of imported nodes will be prefixed with prefix. 155func (g *Graph) Import(def []byte, prefix string) error { 156 return g.ImportWithOptions(def, GraphImportOptions{Prefix: prefix}) 157} 158 159// Operation returns the Operation named name in the Graph, or nil if no such 160// operation is present. 
161func (g *Graph) Operation(name string) *Operation { 162 cname := C.CString(name) 163 defer C.free(unsafe.Pointer(cname)) 164 cop := C.TF_GraphOperationByName(g.c, cname) 165 if cop == nil { 166 return nil 167 } 168 return &Operation{cop, g} 169} 170 171// Operations returns a list of all operations in the graph 172func (g *Graph) Operations() []Operation { 173 var pos C.size_t = 0 174 ops := []Operation{} 175 for { 176 cop := C.TF_GraphNextOperation(g.c, &pos) 177 if cop == nil { 178 break 179 } 180 ops = append(ops, Operation{cop, g}) 181 } 182 return ops 183} 184 185// AddGradients adds operations to compute the partial derivatives of the sum of tensors in y 186// with respect to tensors in x, i.e., d(y[0] + y[1] + ...) / d x[0], d(y[0] + y[1] + ... ) / d x[1] etc. 187// 188// prefix, if non-empty, is the name prefix used for all operations added to the graph to compute 189// these gradients. 190func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) { 191 var ( 192 cprefix *C.char 193 194 cy = make([]C.TF_Output, len(y)) 195 cx = make([]C.TF_Output, len(x)) 196 cdx = make([]C.TF_Output, len(dx)) 197 cdy = make([]C.TF_Output, len(x)) 198 199 pcy *C.TF_Output 200 pcx *C.TF_Output 201 pcdx *C.TF_Output 202 pcdy *C.TF_Output 203 204 status = newStatus() 205 ) 206 207 if len(y) > 0 { 208 pcy = &cy[0] 209 for i, o := range y { 210 cy[i] = o.c() 211 } 212 } 213 if len(x) > 0 { 214 pcx = &cx[0] 215 for i, o := range x { 216 cx[i] = o.c() 217 } 218 pcdy = &cdy[0] 219 } 220 if len(dx) > 0 { 221 pcdx = &cdx[0] 222 for i, o := range dx { 223 cdx[i] = o.c() 224 } 225 } 226 227 // If prefix is "", the C.TF_AddGradientsWithPrefix need cprefix to be nil but not "" 228 if len(prefix) != 0 { 229 cprefix = C.CString(prefix) 230 defer C.free(unsafe.Pointer(cprefix)) 231 } 232 233 C.TF_AddGradientsWithPrefix(g.c, cprefix, pcy, C.int(len(y)), pcx, C.int(len(x)), pcdx, status.c, pcdy) 234 235 if err := status.Err(); err != nil { 236 return 
nil, err 237 } 238 dy := make([]Output, len(x)) 239 for i, co := range cdy { 240 op := &Operation{co.oper, g} 241 dy[i] = Output{op, int(co.index)} 242 } 243 244 return dy, nil 245} 246 247// OpSpec is the specification of an Operation to be added to a Graph 248// (using Graph.AddOperation). 249type OpSpec struct { 250 // Type of the operation (e.g., "Add", "MatMul"). 251 Type string 252 253 // Name by which the added operation will be referred to in the Graph. 254 // If omitted, defaults to Type. 255 Name string 256 257 // Inputs to this operation, which in turn must be outputs 258 // of other operations already added to the Graph. 259 // 260 // An operation may have multiple inputs with individual inputs being 261 // either a single tensor produced by another operation or a list of 262 // tensors produced by multiple operations. For example, the "Concat" 263 // operation takes two inputs: (1) the dimension along which to 264 // concatenate and (2) a list of tensors to concatenate. Thus, for 265 // Concat, len(Input) must be 2, with the first element being an Output 266 // and the second being an OutputList. 267 Input []Input 268 269 // Map from attribute name to its value that will be attached to this 270 // operation. 271 Attrs map[string]interface{} 272 273 // Operations that must be executed before executing the operation 274 // being added. 275 ControlDependencies []*Operation 276 277 // The device on which the operation should be executed. 278 // If omitted, an appropriate device will automatically be selected. 279 // 280 // For example, if set of "/device:GPU:0", then the operation will 281 // execute on GPU #0. 282 Device string 283 284 // Other possible fields: ColocateWith. 285} 286 287// AddOperation adds an operation to g. 
288func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { 289 if args.Name == "" { 290 args.Name = args.Type 291 } 292 cname := C.CString(args.Name) 293 ctype := C.CString(args.Type) 294 cdesc := C.TF_NewOperation(g.c, ctype, cname) 295 C.free(unsafe.Pointer(cname)) 296 C.free(unsafe.Pointer(ctype)) 297 298 for _, in := range args.Input { 299 switch in := in.(type) { 300 case Output: 301 C.TF_AddInput(cdesc, in.c()) 302 case OutputList: 303 size := len(in) 304 list := make([]C.TF_Output, size) 305 for i, v := range in { 306 list[i] = v.c() 307 } 308 if size > 0 { 309 C.TF_AddInputList(cdesc, &list[0], C.int(size)) 310 } else { 311 C.TF_AddInputList(cdesc, nil, 0) 312 } 313 } 314 } 315 for _, in := range args.ControlDependencies { 316 C.TF_AddControlInput(cdesc, in.c) 317 } 318 status := newStatus() 319 for name, value := range args.Attrs { 320 if err := setAttr(cdesc, status, name, value); err != nil { 321 // Memory leak here as the TF_OperationDescription 322 // object will not be cleaned up. At the time of this 323 // writing, this was next to impossible since it 324 // required value to be a string tensor with 325 // incorrectly encoded strings. Given this rarity, live 326 // with the memory leak. If it becomes a real problem, 327 // consider adding a TF_DeleteOperationDescription 328 // function to the C API. 
329 return nil, fmt.Errorf("%v (memory will be leaked)", err) 330 } 331 } 332 if len(args.Device) > 0 { 333 cdevice := C.CString(args.Device) 334 C.TF_SetDevice(cdesc, cdevice) 335 C.free(unsafe.Pointer(cdevice)) 336 } 337 c := C.TF_FinishOperation(cdesc, status.c) 338 if err := status.Err(); err != nil { 339 return nil, err 340 } 341 return &Operation{c, g}, nil 342} 343 344func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { 345 cAttrName := C.CString(name) 346 defer C.free(unsafe.Pointer(cAttrName)) 347 switch value := value.(type) { 348 case string: 349 cstr := C.CString(value) 350 C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value))) 351 C.free(unsafe.Pointer(cstr)) 352 case []string: 353 size := len(value) 354 list := make([]unsafe.Pointer, size) 355 lens := make([]C.size_t, size) 356 for i, s := range value { 357 list[i] = unsafe.Pointer(C.CString(s)) 358 lens[i] = C.size_t(len(s)) 359 } 360 if size > 0 { 361 C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) 362 } else { 363 C.TF_SetAttrStringList(cdesc, cAttrName, nil, nil, 0) 364 } 365 for _, s := range list { 366 C.free(s) 367 } 368 case int64: 369 C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value)) 370 case []int64: 371 size := len(value) 372 list := make([]C.int64_t, size) 373 for i, v := range value { 374 list[i] = C.int64_t(v) 375 } 376 if size > 0 { 377 C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) 378 } else { 379 C.TF_SetAttrIntList(cdesc, cAttrName, nil, 0) 380 } 381 case float32: 382 C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) 383 case []float32: 384 size := len(value) 385 list := make([]C.float, size) 386 for i, v := range value { 387 list[i] = C.float(v) 388 } 389 if size > 0 { 390 C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) 391 } else { 392 C.TF_SetAttrFloatList(cdesc, cAttrName, nil, 0) 393 } 394 case bool: 395 v := C.uchar(0) 396 if value { 397 v = 1 
398 } 399 C.TF_SetAttrBool(cdesc, cAttrName, v) 400 case []bool: 401 size := len(value) 402 list := make([]C.uchar, size) 403 for i, v := range value { 404 if v { 405 list[i] = 1 406 } 407 } 408 if size > 0 { 409 C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) 410 } else { 411 C.TF_SetAttrBoolList(cdesc, cAttrName, nil, 0) 412 } 413 case DataType: 414 C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) 415 case []DataType: 416 var list *C.TF_DataType 417 if len(value) > 0 { 418 list = (*C.TF_DataType)(&value[0]) 419 } 420 C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) 421 case *Tensor: 422 C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c) 423 if err := status.Err(); err != nil { 424 return fmt.Errorf("bad value for attribute %q: %v", name, err) 425 } 426 case []*Tensor: 427 size := len(value) 428 list := make([]*C.TF_Tensor, size) 429 for i, v := range value { 430 list[i] = v.c 431 } 432 var plist **C.TF_Tensor 433 if size > 0 { 434 plist = &list[0] 435 } 436 C.TF_SetAttrTensorList(cdesc, cAttrName, plist, C.int(size), status.c) 437 if err := status.Err(); err != nil { 438 return fmt.Errorf("bad value for attribute %q: %v", name, err) 439 } 440 case Shape: 441 ndims := C.int(value.NumDimensions()) 442 var dimsp *C.int64_t 443 if ndims > 0 { 444 dims := make([]C.int64_t, ndims) 445 for i, d := range value.dims { 446 dims[i] = C.int64_t(d) 447 } 448 dimsp = &dims[0] 449 } 450 C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) 451 case []Shape: 452 if len(value) == 0 { 453 C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0) 454 } else { 455 var flatDims []C.int64_t 456 ndims := make([]C.int, len(value)) 457 for i, s := range value { 458 nd := s.NumDimensions() 459 ndims[i] = C.int(nd) 460 for _, d := range s.dims { 461 flatDims = append(flatDims, C.int64_t(d)) 462 } 463 } 464 var flatDimsp *C.int64_t 465 if len(flatDims) > 0 { 466 flatDimsp = &flatDims[0] 467 } 468 C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, 
&ndims[0], C.int(len(value))) 469 } 470 default: 471 return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) 472 } 473 return nil 474} 475