• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19// #include "tensorflow/c/c_api.h"
20//
21// #include <stdlib.h>
22// #include <string.h>
23//
24// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc,
25//                                 const char* attr_name,
26//                                 const int64_t* flat_dims,
27//                                 const int* num_dims,
28//                                 int num_shapes) {
29//  const int64_t** dims =
30//    (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes);
31//  int i = 0;
32//  for (i = 0; i < num_shapes; i++) {
33//    dims[i] = flat_dims;
34//    if (num_dims[i] > 0) {
35//      // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0.
36//      flat_dims += num_dims[i];
37//    }
38//  }
39//  TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes);
40//  free(dims);
41// }
42import "C"
43
44import (
45	"fmt"
46	"io"
47	"runtime"
48	"unsafe"
49)
50
51// Graph represents a computation graph. Graphs may be shared between sessions.
52type Graph struct {
53	c *C.TF_Graph
54}
55
56// The GraphImportOptions struct holds parameters for the ImportWithOptions function.
57type GraphImportOptions struct {
58	// Node prefix
59	Prefix string
60
61	// Execution device
62	Device string
63
64	// inputMapping defines a mapping between Outputs in the graph
65	// and Outputs they should be replaced with.
66	inputMapping map[struct {
67		Name  string
68		Index int
69	}]Output
70
71	// TODO: extend this structure to support more options from TF_ImportGraphDefOptions
72}
73
74// AddInputMapping adds a mapping between an Output in the imported graph
75// and an Output in the destination graph that it should be replaced with,
76// where src:srcIndex is the name of the Operation and Output index to
77// replace and dst is the output to replace it with.
78func (o *GraphImportOptions) AddInputMapping(src string, srcIndex int, dst Output) {
79	if o.inputMapping == nil {
80		o.inputMapping = make(map[struct {
81			Name  string
82			Index int
83		}]Output)
84	}
85	o.inputMapping[struct {
86		Name  string
87		Index int
88	}{src, srcIndex}] = dst
89}
90
91// NewGraph returns a new Graph.
92func NewGraph() *Graph {
93	g := &Graph{C.TF_NewGraph()}
94	runtime.SetFinalizer(g, (*Graph).finalizer)
95	return g
96}
97
98func (g *Graph) finalizer() {
99	C.TF_DeleteGraph(g.c)
100}
101
102// WriteTo writes out a serialized representation of g to w.
103//
104// Implements the io.WriterTo interface.
105func (g *Graph) WriteTo(w io.Writer) (int64, error) {
106	buf := C.TF_NewBuffer()
107	defer C.TF_DeleteBuffer(buf)
108	status := newStatus()
109	C.TF_GraphToGraphDef(g.c, buf, status.c)
110	if err := status.Err(); err != nil {
111		return 0, err
112	}
113	if buf.length > (1 << 30) {
114		// For very large graphs, the writes can be chunked.
115		// Punt on that for now.
116		return 0, fmt.Errorf("Graph is too large to write out, Graph.WriteTo needs to be updated")
117	}
118	// A []byte slice backed by C memory.
119	// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
120	length := int(buf.length)
121	var slice []byte
122	if unsafe.Sizeof(unsafe.Pointer(nil)) == 8 {
123		slice = (*[1<<50 - 1]byte)(unsafe.Pointer(buf.data))[:length:length]
124	} else {
125		slice = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:length:length]
126	}
127	n, err := w.Write(slice)
128	return int64(n), err
129}
130
131// ImportWithOptions imports the nodes and edges from a serialized representation of
132// another Graph into g.
133//
134// Multiple options can be specified for the newly imported nodes.
135func (g *Graph) ImportWithOptions(def []byte, options GraphImportOptions) error {
136	cprefix := C.CString(options.Prefix)
137	defer C.free(unsafe.Pointer(cprefix))
138
139	opts := C.TF_NewImportGraphDefOptions()
140	defer C.TF_DeleteImportGraphDefOptions(opts)
141	C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix)
142
143	if len(options.Device) != 0 {
144		cdev := C.CString(options.Device)
145		defer C.free(unsafe.Pointer(cdev))
146		C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev)
147	}
148
149	for src, dst := range options.inputMapping {
150		cSrcName := C.CString(src.Name)
151		C.TF_ImportGraphDefOptionsAddInputMapping(opts, cSrcName, C.int(src.Index), dst.c())
152		C.free(unsafe.Pointer(cSrcName))
153	}
154
155	buf := C.TF_NewBuffer()
156	defer C.TF_DeleteBuffer(buf)
157	buf.length = C.size_t(len(def))
158	buf.data = C.CBytes(def)
159	if buf.data == nil {
160		return fmt.Errorf("unable to allocate memory")
161	}
162	defer C.free(buf.data)
163
164	status := newStatus()
165
166	C.TF_GraphImportGraphDef(g.c, buf, opts, status.c)
167	if err := status.Err(); err != nil {
168		return err
169	}
170
171	return nil
172}
173
174// Import imports the nodes and edges from a serialized representation of
175// another Graph into g.
176//
177// Names of imported nodes will be prefixed with prefix.
178func (g *Graph) Import(def []byte, prefix string) error {
179	return g.ImportWithOptions(def, GraphImportOptions{Prefix: prefix})
180}
181
182// Operation returns the Operation named name in the Graph, or nil if no such
183// operation is present.
184func (g *Graph) Operation(name string) *Operation {
185	cname := C.CString(name)
186	defer C.free(unsafe.Pointer(cname))
187	cop := C.TF_GraphOperationByName(g.c, cname)
188	if cop == nil {
189		return nil
190	}
191	return &Operation{cop, g}
192}
193
194// Operations returns a list of all operations in the graph
195func (g *Graph) Operations() []Operation {
196	var pos C.size_t
197	ops := []Operation{}
198	for {
199		cop := C.TF_GraphNextOperation(g.c, &pos)
200		if cop == nil {
201			break
202		}
203		ops = append(ops, Operation{cop, g})
204	}
205	return ops
206}
207
208// AddGradients adds operations to compute the partial derivatives of the sum of tensors in y
209// with respect to tensors in x, i.e., d(y[0] + y[1] + ...) / d x[0], d(y[0] + y[1] + ... ) / d x[1] etc.
210//
211// prefix, if non-empty, is the name prefix used for all operations added to the graph to compute
212// these gradients.
213func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) {
214	var (
215		cprefix *C.char
216
217		cy  = make([]C.TF_Output, len(y))
218		cx  = make([]C.TF_Output, len(x))
219		cdx = make([]C.TF_Output, len(dx))
220		cdy = make([]C.TF_Output, len(x))
221
222		pcy  *C.TF_Output
223		pcx  *C.TF_Output
224		pcdx *C.TF_Output
225		pcdy *C.TF_Output
226
227		status = newStatus()
228	)
229
230	if len(y) > 0 {
231		pcy = &cy[0]
232		for i, o := range y {
233			cy[i] = o.c()
234		}
235	}
236	if len(x) > 0 {
237		pcx = &cx[0]
238		for i, o := range x {
239			cx[i] = o.c()
240		}
241		pcdy = &cdy[0]
242	}
243	if len(dx) > 0 {
244		pcdx = &cdx[0]
245		for i, o := range dx {
246			cdx[i] = o.c()
247		}
248	}
249
250	// If prefix is "", the C.TF_AddGradientsWithPrefix need cprefix to be nil but not ""
251	if len(prefix) != 0 {
252		cprefix = C.CString(prefix)
253		defer C.free(unsafe.Pointer(cprefix))
254	}
255
256	C.TF_AddGradientsWithPrefix(g.c, cprefix, pcy, C.int(len(y)), pcx, C.int(len(x)), pcdx, status.c, pcdy)
257
258	if err := status.Err(); err != nil {
259		return nil, err
260	}
261	dy := make([]Output, len(x))
262	for i, co := range cdy {
263		op := &Operation{co.oper, g}
264		dy[i] = Output{op, int(co.index)}
265	}
266
267	return dy, nil
268}
269
270// OpSpec is the specification of an Operation to be added to a Graph
271// (using Graph.AddOperation).
272type OpSpec struct {
273	// Type of the operation (e.g., "Add", "MatMul").
274	Type string
275
276	// Name by which the added operation will be referred to in the Graph.
277	// If omitted, defaults to Type.
278	Name string
279
280	// Inputs to this operation, which in turn must be outputs
281	// of other operations already added to the Graph.
282	//
283	// An operation may have multiple inputs with individual inputs being
284	// either a single tensor produced by another operation or a list of
285	// tensors produced by multiple operations. For example, the "Concat"
286	// operation takes two inputs: (1) the dimension along which to
287	// concatenate and (2) a list of tensors to concatenate. Thus, for
288	// Concat, len(Input) must be 2, with the first element being an Output
289	// and the second being an OutputList.
290	Input []Input
291
292	// Map from attribute name to its value that will be attached to this
293	// operation.
294	Attrs map[string]interface{}
295
296	// Operations that must be executed before executing the operation
297	// being added.
298	ControlDependencies []*Operation
299
300	// The device on which the operation should be executed.
301	// If omitted, an appropriate device will automatically be selected.
302	//
303	// For example, if set of "/device:GPU:0", then the operation will
304	// execute on GPU #0.
305	Device string
306
307	// Other possible fields: ColocateWith.
308}
309
310// AddOperation adds an operation to g.
311func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
312	if args.Name == "" {
313		args.Name = args.Type
314	}
315	cname := C.CString(args.Name)
316	ctype := C.CString(args.Type)
317	cdesc := C.TF_NewOperation(g.c, ctype, cname)
318	C.free(unsafe.Pointer(cname))
319	C.free(unsafe.Pointer(ctype))
320
321	for _, in := range args.Input {
322		switch in := in.(type) {
323		case Output:
324			C.TF_AddInput(cdesc, in.c())
325		case OutputList:
326			size := len(in)
327			list := make([]C.TF_Output, size)
328			for i, v := range in {
329				list[i] = v.c()
330			}
331			if size > 0 {
332				C.TF_AddInputList(cdesc, &list[0], C.int(size))
333			} else {
334				C.TF_AddInputList(cdesc, nil, 0)
335			}
336		}
337	}
338	for _, in := range args.ControlDependencies {
339		C.TF_AddControlInput(cdesc, in.c)
340	}
341	status := newStatus()
342	for name, value := range args.Attrs {
343		if err := setAttr(cdesc, status, name, value); err != nil {
344			// Memory leak here as the TF_OperationDescription
345			// object will not be cleaned up. At the time of this
346			// writing, this was next to impossible since it
347			// required value to be a string tensor with
348			// incorrectly encoded strings. Given this rarity, live
349			// with the memory leak.  If it becomes a real problem,
350			// consider adding a TF_DeleteOperationDescription
351			// function to the C API.
352			return nil, fmt.Errorf("%v (memory will be leaked)", err)
353		}
354	}
355	if len(args.Device) > 0 {
356		cdevice := C.CString(args.Device)
357		C.TF_SetDevice(cdesc, cdevice)
358		C.free(unsafe.Pointer(cdevice))
359	}
360	c := C.TF_FinishOperation(cdesc, status.c)
361	if err := status.Err(); err != nil {
362		return nil, err
363	}
364	return &Operation{c, g}, nil
365}
366
367func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error {
368	cAttrName := C.CString(name)
369	defer C.free(unsafe.Pointer(cAttrName))
370	switch value := value.(type) {
371	case string:
372		cstr := C.CString(value)
373		C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value)))
374		C.free(unsafe.Pointer(cstr))
375	case []string:
376		size := len(value)
377		list := make([]unsafe.Pointer, size)
378		lens := make([]C.size_t, size)
379		for i, s := range value {
380			list[i] = unsafe.Pointer(C.CString(s))
381			lens[i] = C.size_t(len(s))
382		}
383		if size > 0 {
384			C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size))
385		} else {
386			C.TF_SetAttrStringList(cdesc, cAttrName, nil, nil, 0)
387		}
388		for _, s := range list {
389			C.free(s)
390		}
391	case int64:
392		C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value))
393	case []int64:
394		size := len(value)
395		list := make([]C.int64_t, size)
396		for i, v := range value {
397			list[i] = C.int64_t(v)
398		}
399		if size > 0 {
400			C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size))
401		} else {
402			C.TF_SetAttrIntList(cdesc, cAttrName, nil, 0)
403		}
404	case float32:
405		C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value))
406	case []float32:
407		size := len(value)
408		list := make([]C.float, size)
409		for i, v := range value {
410			list[i] = C.float(v)
411		}
412		if size > 0 {
413			C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size))
414		} else {
415			C.TF_SetAttrFloatList(cdesc, cAttrName, nil, 0)
416		}
417	case bool:
418		v := C.uchar(0)
419		if value {
420			v = 1
421		}
422		C.TF_SetAttrBool(cdesc, cAttrName, v)
423	case []bool:
424		size := len(value)
425		list := make([]C.uchar, size)
426		for i, v := range value {
427			if v {
428				list[i] = 1
429			}
430		}
431		if size > 0 {
432			C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size))
433		} else {
434			C.TF_SetAttrBoolList(cdesc, cAttrName, nil, 0)
435		}
436	case DataType:
437		C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value))
438	case []DataType:
439		var list *C.TF_DataType
440		if len(value) > 0 {
441			list = (*C.TF_DataType)(&value[0])
442		}
443		C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value)))
444	case *Tensor:
445		C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c)
446		if err := status.Err(); err != nil {
447			return fmt.Errorf("bad value for attribute %q: %v", name, err)
448		}
449	case []*Tensor:
450		size := len(value)
451		list := make([]*C.TF_Tensor, size)
452		for i, v := range value {
453			list[i] = v.c
454		}
455		var plist **C.TF_Tensor
456		if size > 0 {
457			plist = &list[0]
458		}
459		C.TF_SetAttrTensorList(cdesc, cAttrName, plist, C.int(size), status.c)
460		if err := status.Err(); err != nil {
461			return fmt.Errorf("bad value for attribute %q: %v", name, err)
462		}
463	case Shape:
464		ndims := C.int(value.NumDimensions())
465		var dimsp *C.int64_t
466		if ndims > 0 {
467			dims := make([]C.int64_t, ndims)
468			for i, d := range value.dims {
469				dims[i] = C.int64_t(d)
470			}
471			dimsp = &dims[0]
472		}
473		C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
474	case []Shape:
475		if len(value) == 0 {
476			C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0)
477		} else {
478			var flatDims []C.int64_t
479			ndims := make([]C.int, len(value))
480			for i, s := range value {
481				nd := s.NumDimensions()
482				ndims[i] = C.int(nd)
483				for _, d := range s.dims {
484					flatDims = append(flatDims, C.int64_t(d))
485				}
486			}
487			var flatDimsp *C.int64_t
488			if len(flatDims) > 0 {
489				flatDimsp = &flatDims[0]
490			}
491			C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value)))
492		}
493	default:
494		return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
495	}
496	return nil
497}
498
499type LibraryHandler struct {
500	cptr *C.TF_Library
501}
502
503// Load library content into current context, useful to load ops implementation into non-monolithic TF build. Returns LibraryHandler or nil and error
504func LoadLibrary(path string) (*LibraryHandler, error) {
505	status := newStatus()
506
507	cpath := C.CString(path)
508	defer C.free(unsafe.Pointer(cpath))
509	cptr := C.TF_LoadLibrary(cpath, status.c)
510	if cptr == nil || status.Code() != C.TF_OK {
511		return nil, fmt.Errorf("could not load library %s: code: %d, error: %s", path, status.Code(), status.String())
512	}
513
514	lh := &LibraryHandler{
515		cptr: cptr,
516	}
517
518	runtime.SetFinalizer(lh, (*LibraryHandler).free)
519	return lh, nil
520}
521
522func (lh *LibraryHandler) free() {
523	if lh == nil || lh.cptr == nil {
524		return
525	}
526
527	C.TF_DeleteLibraryHandle(lh.cptr)
528}
529