• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import java.util.Iterator;
19 
20 /**
21  * A data flow graph representing a TensorFlow computation.
22  *
23  * <p>Instances of a Graph are thread-safe.
24  *
25  * <p><b>WARNING:</b> Resources consumed by the Graph object must be explicitly freed by invoking
26  * the {@link #close()} method when the Graph object is no longer needed.
27  */
28 public final class Graph implements ExecutionEnvironment, AutoCloseable {
29 
30   /** Create an empty Graph. */
Graph()31   public Graph() {
32     nativeHandle = allocate();
33   }
34 
35   /** Create a Graph from an existing handle (takes ownership). */
Graph(long nativeHandle)36   Graph(long nativeHandle) {
37     this.nativeHandle = nativeHandle;
38   }
39 
40   /**
41    * Release resources associated with the Graph.
42    *
43    * <p>Blocks until there are no active {@link Session} instances referring to this Graph. A Graph
44    * is not usable after close returns.
45    */
46   @Override
close()47   public void close() {
48     synchronized (nativeHandleLock) {
49       if (nativeHandle == 0) {
50         return;
51       }
52       while (refcount > 0) {
53         try {
54           nativeHandleLock.wait();
55         } catch (InterruptedException e) {
56           Thread.currentThread().interrupt();
57           // Possible leak of the graph in this case?
58           return;
59         }
60       }
61       delete(nativeHandle);
62       nativeHandle = 0;
63     }
64   }
65 
66   /**
67    * Returns the operation (node in the Graph) with the provided name.
68    *
69    * <p>Or {@code null} if no such operation exists in the Graph.
70    */
operation(String name)71   public GraphOperation operation(String name) {
72     synchronized (nativeHandleLock) {
73       long oph = operation(nativeHandle, name);
74       if (oph == 0) {
75         return null;
76       }
77       return new GraphOperation(this, oph);
78     }
79   }
80 
81   /**
82    * Iterator over all the {@link Operation}s in the graph.
83    *
84    * <p>The order of iteration is unspecified. Consumers of the iterator will receive no
85    * notification should the underlying graph change during iteration.
86    */
operations()87   public Iterator<Operation> operations() {
88     return new OperationIterator(this);
89   }
90 
91   /**
92    * Returns a builder to add {@link Operation}s to the Graph.
93    *
94    * @param type of the Operation (i.e., identifies the computation to be performed)
95    * @param name to refer to the created Operation in the graph.
96    * @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link
97    *     OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked,
98    *     then some resources may leak.
99    */
100   @Override
opBuilder(String type, String name)101   public GraphOperationBuilder opBuilder(String type, String name) {
102     return new GraphOperationBuilder(this, type, name);
103   }
104 
105   /**
106    * Import a serialized representation of a TensorFlow graph.
107    *
108    * <p>The serialized representation of the graph, often referred to as a <i>GraphDef</i>, can be
109    * generated by {@link #toGraphDef()} and equivalents in other language APIs.
110    *
111    * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
112    * @see #importGraphDef(byte[], String)
113    */
importGraphDef(byte[] graphDef)114   public void importGraphDef(byte[] graphDef) throws IllegalArgumentException {
115     importGraphDef(graphDef, "");
116   }
117 
118   /**
119    * Import a serialized representation of a TensorFlow graph.
120    *
121    * @param graphDef the serialized representation of a TensorFlow graph.
122    * @param prefix a prefix that will be prepended to names in graphDef
123    * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
124    * @see #importGraphDef(byte[])
125    */
importGraphDef(byte[] graphDef, String prefix)126   public void importGraphDef(byte[] graphDef, String prefix) throws IllegalArgumentException {
127     if (graphDef == null || prefix == null) {
128       throw new IllegalArgumentException("graphDef and prefix cannot be null");
129     }
130     synchronized (nativeHandleLock) {
131       importGraphDef(nativeHandle, graphDef, prefix);
132     }
133   }
134 
135   /**
136    * Generate a serialized representation of the Graph.
137    *
138    * @see #importGraphDef(byte[])
139    * @see #importGraphDef(byte[], String)
140    */
toGraphDef()141   public byte[] toGraphDef() {
142     synchronized (nativeHandleLock) {
143       return toGraphDef(nativeHandle);
144     }
145   }
146 
147   /**
148    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
149    * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
150    *
151    * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
152    * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
153    * {@code y}.
154    *
155    * <p>If {@code dx} is null, the implementation will use dx of {@link
156    * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
157    *
158    * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
159    * gradients. It must be unique within the provided graph or the operation will fail.
160    *
161    * <p>If {@code prefix} is null, then one will be chosen automatically.
162    *
163    * @param prefix unique string prefix applied before the names of nodes added to the graph to
164    *     compute gradients. If null, a default one will be chosen.
165    * @param y output of the function to derive
166    * @param x inputs of the function for which partial derivatives are computed
167    * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
168    * @return the partial derivatives {@code dy} with the size of {@code x}
169    */
addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx)170   public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
171     Output<?>[] dy = new Output<?>[x.length];
172     final long[] yHandles = new long[y.length];
173     final int[] yIndices = new int[y.length];
174     final long[] xHandles = new long[x.length];
175     final int[] xIndices = new int[x.length];
176     long[] dxHandles = null;
177     int[] dxIndices = null;
178 
179     try (Reference ref = ref()) {
180       for (int i = 0; i < y.length; ++i) {
181         yHandles[i] = y[i].getUnsafeNativeHandle();
182         yIndices[i] = y[i].index();
183       }
184       for (int i = 0; i < x.length; ++i) {
185         xHandles[i] = x[i].getUnsafeNativeHandle();
186         xIndices[i] = x[i].index();
187       }
188       if (dx != null && dx.length > 0) {
189         dxHandles = new long[dx.length];
190         dxIndices = new int[dx.length];
191 
192         for (int i = 0; i < dx.length; ++i) {
193           dxHandles[i] = dx[i].getUnsafeNativeHandle();
194           dxIndices[i] = dx[i].index();
195         }
196       }
197       // Gradient outputs are returned in two continuous arrays concatenated into one. The first
198       // holds the native handles of the gradient operations while the second holds the index of
199       // their output e.g. given
200       // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
201       // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
202       long[] dyHandlesAndIndices =
203           addGradients(
204               ref.nativeHandle(),
205               prefix,
206               yHandles,
207               yIndices,
208               xHandles,
209               xIndices,
210               dxHandles,
211               dxIndices);
212       int ndy = dyHandlesAndIndices.length >> 1;
213       if (ndy != dy.length) {
214         throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
215             + " were expected");
216       }
217       for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
218         GraphOperation op = new GraphOperation(this, dyHandlesAndIndices[i]);
219         dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
220       }
221     }
222     return dy;
223   }
224 
225   /**
226    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
227    * i.e., {@code dy/dx_1, dy/dx_2...}
228    * <p>
229    * This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])}
230    * where {@code y} is a single output, {@code dx} is null and {@code prefix} is null.
231    *
232    * @param y output of the function to derive
233    * @param x inputs of the function for which partial derivatives are computed
234    * @return the partial derivatives {@code dy} with the size of {@code x}
235    */
addGradients(Output<?> y, Output<?>[] x)236   public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
237     return addGradients(null, new Output<?>[] {y}, x, null);
238   }
239 
240   /**
241    * Used to instantiate an abstract class which overrides the buildSubgraph method to build a
242    * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
243    * create a lambda for the same purpose.
244    *
245    * <p>To be used when calling {@link #whileLoop(Output[],
246    * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)}
247    *
248    * <p>Example usage (prior to Java 8):
249    *
250    * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() { @Override public
251    * void buildSubgraph(Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) { // build
252    * body subgraph } }; }
253    *
254    * <p>Example usage (after Java 8):
255    *
256    * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { //
257    * build body subgraph };}
258    */
259   public interface WhileSubgraphBuilder {
260     /**
261      * To be overridden by user with code to build conditional or body subgraph for a while loop
262      *
263      * @param g the subgraph
264      * @param inputs subgraph inputs
265      * @param outputs subgraph outputs
266      */
buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs)267     public void buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs);
268   }
269 
270   // called by while loop code in graph_jni.cc to construct conditional/body subgraphs
buildSubgraph( WhileSubgraphBuilder subgraphBuilder, long subgraphHandle, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices)271   private static long[] buildSubgraph(
272       WhileSubgraphBuilder subgraphBuilder,
273       long subgraphHandle,
274       long[] inputHandles,
275       int[] inputIndices,
276       long[] outputHandles,
277       int[] outputIndices) {
278     Graph subgraph = new Graph(subgraphHandle);
279 
280     int ninputs = inputHandles.length;
281     int noutputs = outputHandles.length;
282     Output<?>[] inputs = new Output<?>[ninputs];
283     Output<?>[] outputs = new Output<?>[noutputs];
284     long[] outputHandlesAndIndices = new long[noutputs * 2];
285 
286     synchronized (subgraph.nativeHandleLock) {
287       try (Reference ref = subgraph.ref()) {
288 
289         for (int i = 0; i < ninputs; i++) {
290           Operation op = new GraphOperation(subgraph, inputHandles[i]);
291           inputs[i] = op.output(inputIndices[i]);
292         }
293 
294         for (int i = 0; i < noutputs; i++) {
295           Operation op = new GraphOperation(subgraph, outputHandles[i]);
296           outputs[i] = op.output(outputIndices[i]);
297         }
298 
299         subgraphBuilder.buildSubgraph(subgraph, inputs, outputs);
300 
301         for (int i = 0, j = noutputs; i < noutputs; i++, j++) {
302           outputHandlesAndIndices[i] = outputs[i].getUnsafeNativeHandle();
303           outputHandlesAndIndices[j] = (long) outputs[i].index();
304         }
305       }
306       return outputHandlesAndIndices;
307     }
308   }
309 
310   /**
311    * Builds a while loop.
312    *
313    * @param inputs the loop inputs
314    * @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph
315    * @param bgBuilder WhileSubgraphBuilder to build the body subgraph
316    * @param name name for the loop
317    * @return list of loop outputs, of the same length as {@code inputs}
318    */
whileLoop( Output<?>[] inputs, WhileSubgraphBuilder cgBuilder, WhileSubgraphBuilder bgBuilder, String name)319   public Output<?>[] whileLoop(
320       Output<?>[] inputs,
321       WhileSubgraphBuilder cgBuilder,
322       WhileSubgraphBuilder bgBuilder,
323       String name) {
324     int ninputs = inputs.length;
325     long[] inputHandles = new long[ninputs];
326     int[] inputIndices = new int[ninputs];
327     Output<?>[] outputs = new Output<?>[ninputs];
328 
329     synchronized (nativeHandleLock) {
330       try (Reference ref = ref()) {
331 
332         for (int i = 0; i < ninputs; i++) {
333           inputHandles[i] = inputs[i].getUnsafeNativeHandle();
334           inputIndices[i] = inputs[i].index();
335         }
336 
337         long[] outputHandlesAndIndices =
338             whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder);
339 
340         for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
341           Operation op = new GraphOperation(this, outputHandlesAndIndices[i]);
342           outputs[i] = op.output((int) outputHandlesAndIndices[j]);
343         }
344       }
345       return outputs;
346     }
347   }
348 
  // Monitor that the methods of this class synchronize on when they read or modify nativeHandle
  // and refcount; close() also waits on it until refcount drops to zero.
349   private final Object nativeHandleLock = new Object();
  // Handle of the native graph; close() resets it to 0 after deleting the native object.
350   private long nativeHandle;
  // Number of active Reference instances: incremented when a Reference is created and
  // decremented when that Reference is closed.
351   private int refcount = 0;
352 
353   // Related native objects (such as the TF_Operation object backing an Operation instance)
354   // have a validity tied to that of the Graph. The handles to those native objects are not
355   // valid after Graph.close() has been invoked.
356   //
357   // Instances of the Reference class should be used to ensure the Graph has not been closed
358   // while dependent handles are in use.
359   class Reference implements AutoCloseable {
Reference()360     private Reference() {
361       synchronized (Graph.this.nativeHandleLock) {
362         active = Graph.this.nativeHandle != 0;
363         if (!active) {
364           throw new IllegalStateException("close() has been called on the Graph");
365         }
366         active = true;
367         Graph.this.refcount++;
368       }
369     }
370 
371     @Override
close()372     public void close() {
373       synchronized (Graph.this.nativeHandleLock) {
374         if (!active) {
375           return;
376         }
377         active = false;
378         if (--Graph.this.refcount == 0) {
379           Graph.this.nativeHandleLock.notifyAll();
380         }
381       }
382     }
383 
nativeHandle()384     public long nativeHandle() {
385       synchronized (Graph.this.nativeHandleLock) {
386         return active ? Graph.this.nativeHandle : 0;
387       }
388     }
389 
390     private boolean active;
391   }
392 
ref()393   Reference ref() {
394     return new Reference();
395   }
396 
397   private static final class OperationIterator implements Iterator<Operation> {
398 
OperationIterator(Graph g)399     OperationIterator(Graph g) {
400       this.graph = g;
401       this.operation = null;
402       this.position = 0;
403       this.advance();
404     }
405 
advance()406     private final void advance() {
407       Graph.Reference reference = this.graph.ref();
408 
409       this.operation = null;
410 
411       try {
412         long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position);
413 
414         if ((nativeReturn != null) && (nativeReturn[0] != 0)) {
415           this.operation = new GraphOperation(this.graph, nativeReturn[0]);
416           this.position = (int) nativeReturn[1];
417         }
418       } finally {
419         reference.close();
420       }
421     }
422 
423     @Override
hasNext()424     public boolean hasNext() {
425       return (this.operation != null);
426     }
427 
428     @Override
next()429     public Operation next() {
430       Operation rhett = this.operation;
431       this.advance();
432       return rhett;
433     }
434 
435     @Override
remove()436     public void remove() {
437       throw new UnsupportedOperationException("remove() is unsupported.");
438     }
439 
440     private final Graph graph;
441     private Operation operation;
442     private int position;
443   }
444 
  // Allocates a native graph and returns its handle; the public constructor stores the result as
  // nativeHandle.
allocate()445   private static native long allocate();
446 
  // Deletes the native graph behind the handle; invoked from close().
delete(long handle)447   private static native void delete(long handle);
448 
  // Looks up an operation by name; operation(String) treats a returned 0 as "no such operation".
operation(long handle, String name)449   private static native long operation(long handle, String name);
450 
451   // This method returns the Operation native handle at index 0 and the new value for pos at index 1
452   // (see TF_GraphNextOperation)
nextOperation(long handle, int position)453   private static native long[] nextOperation(long handle, int position);
454 
  // Native counterpart of importGraphDef(byte[], String), which rejects null arguments before
  // calling it.
importGraphDef(long handle, byte[] graphDef, String prefix)455   private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
456       throws IllegalArgumentException;
457 
  // Native counterpart of toGraphDef().
toGraphDef(long handle)458   private static native byte[] toGraphDef(long handle);
459 
  // Native counterpart of the public addGradients. Note the naming: the caller passes the y
  // handles/indices as inputHandles/inputIndices, the x ones as outputHandles/outputIndices and
  // the dx ones (possibly null) as gradInputHandles/gradInputIndices. The caller reads the
  // result as the gradient operation handles followed by their output indices.
addGradients( long handle, String prefix, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices)460   private static native long[] addGradients(
461       long handle,
462       String prefix,
463       long[] inputHandles,
464       int[] inputIndices,
465       long[] outputHandles,
466       int[] outputIndices,
467       long[] gradInputHandles,
468       int[] gradInputIndices);
469 
  // Native counterpart of the public whileLoop. The caller reads the result as the loop output
  // operation handles followed by their output indices.
whileLoop( long handle, long[] inputHandles, int[] inputIndices, String name, WhileSubgraphBuilder condGraphBuilder, WhileSubgraphBuilder bodyGraphBuilder)470   private static native long[] whileLoop(
471       long handle,
472       long[] inputHandles,
473       int[] inputIndices,
474       String name,
475       WhileSubgraphBuilder condGraphBuilder,
476       WhileSubgraphBuilder bodyGraphBuilder);
477 
  // Static initializer: runs TensorFlow.init() once when this class is initialized.
  // NOTE(review): presumably this loads the native library backing the native methods of this
  // class — confirm in TensorFlow.init().
478   static {
TensorFlow.init()479     TensorFlow.init();
480   }
481 }
482