• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19// #include "tensorflow/c/c_api.h"
20//
21// #include <stdlib.h>
22// #include <string.h>
23//
// // TF_SetAttrShapeList_Helper sets the shape-list attribute attr_name on desc.
// //
// // The dimensions of all shapes are passed concatenated in flat_dims, and
// // num_dims[i] is the number of dimensions of the i-th shape. The helper
// // builds the temporary array of per-shape pointers that TF_SetAttrShapeList
// // takes (dims[i] points at the first dimension of shape i inside flat_dims)
// // and frees that array again after the call.
// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc,
//                                 const char* attr_name,
//                                 const int64_t* flat_dims,
//                                 const int* num_dims,
//                                 int num_shapes) {
//  const int64_t** dims =
//    (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes);
//  int i = 0;
//  for (i = 0; i < num_shapes; i++) {
//    dims[i] = flat_dims;
//    // Only advance past shapes that actually contribute dimensions; entries
//    // with num_dims[i] <= 0 occupy no space in flat_dims.
//    if (num_dims[i] > 0) {
//      // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0.
//      flat_dims += num_dims[i];
//    }
//  }
//  TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes);
//  free(dims);
// }
42import "C"
43
44import (
45	"fmt"
46	"io"
47	"runtime"
48	"unsafe"
49)
50
// Graph represents a computation graph. Graphs may be shared between sessions.
//
// A Graph returned by NewGraph has a finalizer attached that deletes the
// underlying C graph.
type Graph struct {
	// c is the handle to the graph owned by the TensorFlow C library.
	c *C.TF_Graph
}
55
// GraphImportOptions holds the options that can be supplied to
// Graph.ImportWithOptions when importing a serialized graph.
type GraphImportOptions struct {
	// Prefix is the prefix set on the import options for the names of the
	// imported nodes.
	Prefix string

	// Device is the execution device requested for the imported nodes.
	// NOTE(review): ImportWithOptions currently returns an error whenever
	// this field is non-empty.
	Device string

	// TODO: extend this structure to support more options from TF_ImportGraphDefOptions
}
66
67// NewGraph returns a new Graph.
68func NewGraph() *Graph {
69	g := &Graph{C.TF_NewGraph()}
70	runtime.SetFinalizer(g, (*Graph).finalizer)
71	return g
72}
73
74func (g *Graph) finalizer() {
75	C.TF_DeleteGraph(g.c)
76}
77
78// WriteTo writes out a serialized representation of g to w.
79//
80// Implements the io.WriterTo interface.
81func (g *Graph) WriteTo(w io.Writer) (int64, error) {
82	buf := C.TF_NewBuffer()
83	defer C.TF_DeleteBuffer(buf)
84	status := newStatus()
85	C.TF_GraphToGraphDef(g.c, buf, status.c)
86	if err := status.Err(); err != nil {
87		return 0, err
88	}
89	if buf.length > (1 << 30) {
90		// For very large graphs, the writes can be chunked.
91		// Punt on that for now.
92		return 0, fmt.Errorf("Graph is too large to write out, Graph.WriteTo needs to be updated")
93	}
94	// A []byte slice backed by C memory.
95	// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
96	length := int(buf.length)
97	slice := (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:length:length]
98	n, err := w.Write(slice)
99	return int64(n), err
100}
101
102// ImportWithOptions imports the nodes and edges from a serialized representation of
103// another Graph into g.
104//
105// Multiple options can be specified for the newly imported nodes.
106func (g *Graph) ImportWithOptions(def []byte, options GraphImportOptions) error {
107	cprefix := C.CString(options.Prefix)
108	defer C.free(unsafe.Pointer(cprefix))
109
110	opts := C.TF_NewImportGraphDefOptions()
111	defer C.TF_DeleteImportGraphDefOptions(opts)
112	C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix)
113
114	if len(options.Device) != 0 {
115		// TODO(ashankar): Remove this error and uncomment below
116		// when a release of the C library which includes
117		// https://github.com/tensorflow/tensorflow/commit/e0af5ac53e5a8ad9b07cdd5738c0a8e12f938c4e
118		// has been made.
119		// See https://github.com/tensorflow/tensorflow/issues/23257
120		return fmt.Errorf("GraphImportOptions.Device is only supported with the TensorFlow C library versions after 1.12 (or built from master). See https://github.com/tensorflow/tensorflow/issues/23257")
121		/*
122			cdev := C.CString(options.Device)
123			defer C.free(unsafe.Pointer(cdev))
124			C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev)
125		*/
126	}
127
128	buf := C.TF_NewBuffer()
129	defer C.TF_DeleteBuffer(buf)
130	// Would have preferred to use C.CBytes, but that does not play well
131	// with "go vet" till https://github.com/golang/go/issues/17201 is
132	// resolved.
133	buf.length = C.size_t(len(def))
134	buf.data = C.malloc(buf.length)
135	if buf.data == nil {
136		return fmt.Errorf("unable to allocate memory")
137	}
138	defer C.free(buf.data)
139	C.memcpy(buf.data, unsafe.Pointer(&def[0]), buf.length)
140
141	status := newStatus()
142
143	C.TF_GraphImportGraphDef(g.c, buf, opts, status.c)
144	if err := status.Err(); err != nil {
145		return err
146	}
147
148	return nil
149}
150
151// Import imports the nodes and edges from a serialized representation of
152// another Graph into g.
153//
154// Names of imported nodes will be prefixed with prefix.
155func (g *Graph) Import(def []byte, prefix string) error {
156	return g.ImportWithOptions(def, GraphImportOptions{Prefix: prefix})
157}
158
159// Operation returns the Operation named name in the Graph, or nil if no such
160// operation is present.
161func (g *Graph) Operation(name string) *Operation {
162	cname := C.CString(name)
163	defer C.free(unsafe.Pointer(cname))
164	cop := C.TF_GraphOperationByName(g.c, cname)
165	if cop == nil {
166		return nil
167	}
168	return &Operation{cop, g}
169}
170
171// Operations returns a list of all operations in the graph
172func (g *Graph) Operations() []Operation {
173	var pos C.size_t = 0
174	ops := []Operation{}
175	for {
176		cop := C.TF_GraphNextOperation(g.c, &pos)
177		if cop == nil {
178			break
179		}
180		ops = append(ops, Operation{cop, g})
181	}
182	return ops
183}
184
185// AddGradients adds operations to compute the partial derivatives of the sum of tensors in y
186// with respect to tensors in x, i.e., d(y[0] + y[1] + ...) / d x[0], d(y[0] + y[1] + ... ) / d x[1] etc.
187//
188// prefix, if non-empty, is the name prefix used for all operations added to the graph to compute
189// these gradients.
190func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) {
191	var (
192		cprefix *C.char
193
194		cy  = make([]C.TF_Output, len(y))
195		cx  = make([]C.TF_Output, len(x))
196		cdx = make([]C.TF_Output, len(dx))
197		cdy = make([]C.TF_Output, len(x))
198
199		pcy  *C.TF_Output
200		pcx  *C.TF_Output
201		pcdx *C.TF_Output
202		pcdy *C.TF_Output
203
204		status = newStatus()
205	)
206
207	if len(y) > 0 {
208		pcy = &cy[0]
209		for i, o := range y {
210			cy[i] = o.c()
211		}
212	}
213	if len(x) > 0 {
214		pcx = &cx[0]
215		for i, o := range x {
216			cx[i] = o.c()
217		}
218		pcdy = &cdy[0]
219	}
220	if len(dx) > 0 {
221		pcdx = &cdx[0]
222		for i, o := range dx {
223			cdx[i] = o.c()
224		}
225	}
226
227	// If prefix is "", the C.TF_AddGradientsWithPrefix need cprefix to be nil but not ""
228	if len(prefix) != 0 {
229		cprefix = C.CString(prefix)
230		defer C.free(unsafe.Pointer(cprefix))
231	}
232
233	C.TF_AddGradientsWithPrefix(g.c, cprefix, pcy, C.int(len(y)), pcx, C.int(len(x)), pcdx, status.c, pcdy)
234
235	if err := status.Err(); err != nil {
236		return nil, err
237	}
238	dy := make([]Output, len(x))
239	for i, co := range cdy {
240		op := &Operation{co.oper, g}
241		dy[i] = Output{op, int(co.index)}
242	}
243
244	return dy, nil
245}
246
// OpSpec is the specification of an Operation to be added to a Graph
// (using Graph.AddOperation).
type OpSpec struct {
	// Type of the operation (e.g., "Add", "MatMul").
	Type string

	// Name by which the added operation will be referred to in the Graph.
	// If omitted, defaults to Type.
	Name string

	// Inputs to this operation, which in turn must be outputs
	// of other operations already added to the Graph.
	//
	// An operation may have multiple inputs with individual inputs being
	// either a single tensor produced by another operation or a list of
	// tensors produced by multiple operations. For example, the "Concat"
	// operation takes two inputs: (1) the dimension along which to
	// concatenate and (2) a list of tensors to concatenate. Thus, for
	// Concat, len(Input) must be 2, with the first element being an Output
	// and the second being an OutputList.
	Input []Input

	// Map from attribute name to its value that will be attached to this
	// operation.
	Attrs map[string]interface{}

	// Operations that must be executed before executing the operation
	// being added.
	ControlDependencies []*Operation

	// The device on which the operation should be executed.
	// If omitted, an appropriate device will automatically be selected.
	//
	// For example, if set to "/device:GPU:0", then the operation will
	// execute on GPU #0.
	Device string

	// Other possible fields: ColocateWith.
}
286
287// AddOperation adds an operation to g.
288func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
289	if args.Name == "" {
290		args.Name = args.Type
291	}
292	cname := C.CString(args.Name)
293	ctype := C.CString(args.Type)
294	cdesc := C.TF_NewOperation(g.c, ctype, cname)
295	C.free(unsafe.Pointer(cname))
296	C.free(unsafe.Pointer(ctype))
297
298	for _, in := range args.Input {
299		switch in := in.(type) {
300		case Output:
301			C.TF_AddInput(cdesc, in.c())
302		case OutputList:
303			size := len(in)
304			list := make([]C.TF_Output, size)
305			for i, v := range in {
306				list[i] = v.c()
307			}
308			if size > 0 {
309				C.TF_AddInputList(cdesc, &list[0], C.int(size))
310			} else {
311				C.TF_AddInputList(cdesc, nil, 0)
312			}
313		}
314	}
315	for _, in := range args.ControlDependencies {
316		C.TF_AddControlInput(cdesc, in.c)
317	}
318	status := newStatus()
319	for name, value := range args.Attrs {
320		if err := setAttr(cdesc, status, name, value); err != nil {
321			// Memory leak here as the TF_OperationDescription
322			// object will not be cleaned up. At the time of this
323			// writing, this was next to impossible since it
324			// required value to be a string tensor with
325			// incorrectly encoded strings. Given this rarity, live
326			// with the memory leak.  If it becomes a real problem,
327			// consider adding a TF_DeleteOperationDescription
328			// function to the C API.
329			return nil, fmt.Errorf("%v (memory will be leaked)", err)
330		}
331	}
332	if len(args.Device) > 0 {
333		cdevice := C.CString(args.Device)
334		C.TF_SetDevice(cdesc, cdevice)
335		C.free(unsafe.Pointer(cdevice))
336	}
337	c := C.TF_FinishOperation(cdesc, status.c)
338	if err := status.Err(); err != nil {
339		return nil, err
340	}
341	return &Operation{c, g}, nil
342}
343
344func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error {
345	cAttrName := C.CString(name)
346	defer C.free(unsafe.Pointer(cAttrName))
347	switch value := value.(type) {
348	case string:
349		cstr := C.CString(value)
350		C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value)))
351		C.free(unsafe.Pointer(cstr))
352	case []string:
353		size := len(value)
354		list := make([]unsafe.Pointer, size)
355		lens := make([]C.size_t, size)
356		for i, s := range value {
357			list[i] = unsafe.Pointer(C.CString(s))
358			lens[i] = C.size_t(len(s))
359		}
360		if size > 0 {
361			C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size))
362		} else {
363			C.TF_SetAttrStringList(cdesc, cAttrName, nil, nil, 0)
364		}
365		for _, s := range list {
366			C.free(s)
367		}
368	case int64:
369		C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value))
370	case []int64:
371		size := len(value)
372		list := make([]C.int64_t, size)
373		for i, v := range value {
374			list[i] = C.int64_t(v)
375		}
376		if size > 0 {
377			C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size))
378		} else {
379			C.TF_SetAttrIntList(cdesc, cAttrName, nil, 0)
380		}
381	case float32:
382		C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value))
383	case []float32:
384		size := len(value)
385		list := make([]C.float, size)
386		for i, v := range value {
387			list[i] = C.float(v)
388		}
389		if size > 0 {
390			C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size))
391		} else {
392			C.TF_SetAttrFloatList(cdesc, cAttrName, nil, 0)
393		}
394	case bool:
395		v := C.uchar(0)
396		if value {
397			v = 1
398		}
399		C.TF_SetAttrBool(cdesc, cAttrName, v)
400	case []bool:
401		size := len(value)
402		list := make([]C.uchar, size)
403		for i, v := range value {
404			if v {
405				list[i] = 1
406			}
407		}
408		if size > 0 {
409			C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size))
410		} else {
411			C.TF_SetAttrBoolList(cdesc, cAttrName, nil, 0)
412		}
413	case DataType:
414		C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value))
415	case []DataType:
416		var list *C.TF_DataType
417		if len(value) > 0 {
418			list = (*C.TF_DataType)(&value[0])
419		}
420		C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value)))
421	case *Tensor:
422		C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c)
423		if err := status.Err(); err != nil {
424			return fmt.Errorf("bad value for attribute %q: %v", name, err)
425		}
426	case []*Tensor:
427		size := len(value)
428		list := make([]*C.TF_Tensor, size)
429		for i, v := range value {
430			list[i] = v.c
431		}
432		var plist **C.TF_Tensor
433		if size > 0 {
434			plist = &list[0]
435		}
436		C.TF_SetAttrTensorList(cdesc, cAttrName, plist, C.int(size), status.c)
437		if err := status.Err(); err != nil {
438			return fmt.Errorf("bad value for attribute %q: %v", name, err)
439		}
440	case Shape:
441		ndims := C.int(value.NumDimensions())
442		var dimsp *C.int64_t
443		if ndims > 0 {
444			dims := make([]C.int64_t, ndims)
445			for i, d := range value.dims {
446				dims[i] = C.int64_t(d)
447			}
448			dimsp = &dims[0]
449		}
450		C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
451	case []Shape:
452		if len(value) == 0 {
453			C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0)
454		} else {
455			var flatDims []C.int64_t
456			ndims := make([]C.int, len(value))
457			for i, s := range value {
458				nd := s.NumDimensions()
459				ndims[i] = C.int(nd)
460				for _, d := range s.dims {
461					flatDims = append(flatDims, C.int64_t(d))
462				}
463			}
464			var flatDimsp *C.int64_t
465			if len(flatDims) > 0 {
466				flatDimsp = &flatDims[0]
467			}
468			C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value)))
469		}
470	default:
471		return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
472	}
473	return nil
474}
475