• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
import java.util.Iterator;
import java.util.NoSuchElementException;
19 
20 /**
21  * A data flow graph representing a TensorFlow computation.
22  *
23  * <p>Instances of a Graph are thread-safe.
24  *
25  * <p><b>WARNING:</b> Resources consumed by the Graph object must be explicitly freed by invoking
26  * the {@link #close()} method then the Graph object is no longer needed.
27  */
28 public final class Graph implements AutoCloseable {
29 
30   /** Create an empty Graph. */
Graph()31   public Graph() {
32     nativeHandle = allocate();
33   }
34 
35   /** Create a Graph from an existing handle (takes ownership). */
Graph(long nativeHandle)36   Graph(long nativeHandle) {
37     this.nativeHandle = nativeHandle;
38   }
39 
40   /**
41    * Release resources associated with the Graph.
42    *
43    * <p>Blocks until there are no active {@link Session} instances referring to this Graph. A Graph
44    * is not usable after close returns.
45    */
46   @Override
close()47   public void close() {
48     synchronized (nativeHandleLock) {
49       if (nativeHandle == 0) {
50         return;
51       }
52       while (refcount > 0) {
53         try {
54           nativeHandleLock.wait();
55         } catch (InterruptedException e) {
56           Thread.currentThread().interrupt();
57           // Possible leak of the graph in this case?
58           return;
59         }
60       }
61       delete(nativeHandle);
62       nativeHandle = 0;
63     }
64   }
65 
66   /**
67    * Returns the operation (node in the Graph) with the provided name.
68    *
69    * <p>Or {@code null} if no such operation exists in the Graph.
70    */
operation(String name)71   public Operation operation(String name) {
72     synchronized (nativeHandleLock) {
73       long oph = operation(nativeHandle, name);
74       if (oph == 0) {
75         return null;
76       }
77       return new Operation(this, oph);
78     }
79   }
80 
81   /**
82    * Iterator over all the {@link Operation}s in the graph.
83    *
84    * <p>The order of iteration is unspecified. Consumers of the iterator will receive no
85    * notification should the underlying graph change during iteration.
86    */
operations()87   public Iterator<Operation> operations() {
88     return new OperationIterator(this);
89   }
90 
91   /**
92    * Returns a builder to add {@link Operation}s to the Graph.
93    *
94    * @param type of the Operation (i.e., identifies the computation to be performed)
95    * @param name to refer to the created Operation in the graph.
96    * @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link
97    *     OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked,
98    *     then some resources may leak.
99    */
opBuilder(String type, String name)100   public OperationBuilder opBuilder(String type, String name) {
101     return new OperationBuilder(this, type, name);
102   }
103 
104   /**
105    * Import a serialized representation of a TensorFlow graph.
106    *
107    * <p>The serialized representation of the graph, often referred to as a <i>GraphDef</i>, can be
108    * generated by {@link #toGraphDef()} and equivalents in other language APIs.
109    *
110    * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
111    * @see #importGraphDef(byte[], String)
112    */
importGraphDef(byte[] graphDef)113   public void importGraphDef(byte[] graphDef) throws IllegalArgumentException {
114     importGraphDef(graphDef, "");
115   }
116 
117   /**
118    * Import a serialized representation of a TensorFlow graph.
119    *
120    * @param graphDef the serialized representation of a TensorFlow graph.
121    * @param prefix a prefix that will be prepended to names in graphDef
122    * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
123    * @see #importGraphDef(byte[])
124    */
importGraphDef(byte[] graphDef, String prefix)125   public void importGraphDef(byte[] graphDef, String prefix) throws IllegalArgumentException {
126     if (graphDef == null || prefix == null) {
127       throw new IllegalArgumentException("graphDef and prefix cannot be null");
128     }
129     synchronized (nativeHandleLock) {
130       importGraphDef(nativeHandle, graphDef, prefix);
131     }
132   }
133 
134   /**
135    * Generate a serialized representation of the Graph.
136    *
137    * @see #importGraphDef(byte[])
138    * @see #importGraphDef(byte[], String)
139    */
toGraphDef()140   public byte[] toGraphDef() {
141     synchronized (nativeHandleLock) {
142       return toGraphDef(nativeHandle);
143     }
144   }
145 
146   /**
147    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
148    * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
149    *
150    * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
151    * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
152    * {@code y}.
153    *
154    * <p>If {@code dx} is null, the implementation will use dx of {@link
155    * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
156    *
157    * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
158    * gradients. It must be unique within the provided graph or the operation will fail.
159    *
160    * <p>If {@code prefix} is null, then one will be chosen automatically.
161    *
162    * @param prefix unique string prefix applied before the names of nodes added to the graph to
163    *     compute gradients. If null, a default one will be chosen.
164    * @param y output of the function to derive
165    * @param x inputs of the function for which partial derivatives are computed
166    * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
167    * @return the partial derivatives {@code dy} with the size of {@code x}
168    */
addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx)169   public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
170     Output<?>[] dy = new Output<?>[x.length];
171     final long[] yHandles = new long[y.length];
172     final int[] yIndices = new int[y.length];
173     final long[] xHandles = new long[x.length];
174     final int[] xIndices = new int[x.length];
175     long[] dxHandles = null;
176     int[] dxIndices = null;
177 
178     try (Reference ref = ref()) {
179       for (int i = 0; i < y.length; ++i) {
180         yHandles[i] = y[i].op().getUnsafeNativeHandle();
181         yIndices[i] = y[i].index();
182       }
183       for (int i = 0; i < x.length; ++i) {
184         xHandles[i] = x[i].op().getUnsafeNativeHandle();
185         xIndices[i] = x[i].index();
186       }
187       if (dx != null && dx.length > 0) {
188         dxHandles = new long[dx.length];
189         dxIndices = new int[dx.length];
190 
191         for (int i = 0; i < dx.length; ++i) {
192           dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
193           dxIndices[i] = dx[i].index();
194         }
195       }
196       // Gradient outputs are returned in two continuous arrays concatenated into one. The first
197       // holds the native handles of the gradient operations while the second holds the index of
198       // their output e.g. given
199       // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
200       // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
201       long[] dyHandlesAndIndices =
202           addGradients(
203               ref.nativeHandle(),
204               prefix,
205               yHandles,
206               yIndices,
207               xHandles,
208               xIndices,
209               dxHandles,
210               dxIndices);
211       int ndy = dyHandlesAndIndices.length >> 1;
212       if (ndy != dy.length) {
213         throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
214             + " were expected");
215       }
216       for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
217         Operation op = new Operation(this, dyHandlesAndIndices[i]);
218         dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
219       }
220     }
221     return dy;
222   }
223 
224   /**
225    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
226    * i.e., {@code dy/dx_1, dy/dx_2...}
227    * <p>
228    * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
229    * a single output, {@code dx} is null and {@code prefix} is null.
230    *
231    * @param y output of the function to derive
232    * @param x inputs of the function for which partial derivatives are computed
233    * @return the partial derivatives {@code dy} with the size of {@code x}
234    */
addGradients(Output<?> y, Output<?>[] x)235   public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
236     return addGradients(null, new Output<?>[] {y}, x, null);
237   }
238 
239   /**
240    * Used to instantiate an abstract class which overrides the buildSubgraph method to build a
241    * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
242    * create a lambda for the same purpose.
243    *
244    * <p>To be used when calling {@link #whileLoop(Output[],
245    * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)}
246    *
247    * <p>Example usage (prior to Java 8):
248    *
249    * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() { @Override public
250    * void buildSubgraph(Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) { // build
251    * body subgraph } }; }
252    *
253    * <p>Example usage (after Java 8):
254    *
255    * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { //
256    * build body subgraph };}
257    */
258   public interface WhileSubgraphBuilder {
259     /**
260      * To be overridden by user with code to build conditional or body subgraph for a while loop
261      *
262      * @param g the subgraph
263      * @param inputs subgraph inputs
264      * @param outputs subgraph outputs
265      */
buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs)266     public void buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs);
267   }
268 
269   // called by while loop code in graph_jni.cc to construct conditional/body subgraphs
buildSubgraph( WhileSubgraphBuilder subgraphBuilder, long subgraphHandle, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices)270   private static long[] buildSubgraph(
271       WhileSubgraphBuilder subgraphBuilder,
272       long subgraphHandle,
273       long[] inputHandles,
274       int[] inputIndices,
275       long[] outputHandles,
276       int[] outputIndices) {
277     Graph subgraph = new Graph(subgraphHandle);
278 
279     int ninputs = inputHandles.length;
280     int noutputs = outputHandles.length;
281     Output<?>[] inputs = new Output<?>[ninputs];
282     Output<?>[] outputs = new Output<?>[noutputs];
283     long[] outputHandlesAndIndices = new long[noutputs * 2];
284 
285     synchronized (subgraph.nativeHandleLock) {
286       try (Reference ref = subgraph.ref()) {
287 
288         for (int i = 0; i < ninputs; i++) {
289           Operation op = new Operation(subgraph, inputHandles[i]);
290           inputs[i] = new Output<>(op, inputIndices[i]);
291         }
292 
293         for (int i = 0; i < noutputs; i++) {
294           Operation op = new Operation(subgraph, outputHandles[i]);
295           outputs[i] = new Output<>(op, outputIndices[i]);
296         }
297 
298         subgraphBuilder.buildSubgraph(subgraph, inputs, outputs);
299 
300         for (int i = 0, j = noutputs; i < noutputs; i++, j++) {
301           outputHandlesAndIndices[i] = outputs[i].op().getUnsafeNativeHandle();
302           outputHandlesAndIndices[j] = (long) outputs[i].index();
303         }
304       }
305       return outputHandlesAndIndices;
306     }
307   }
308 
309   /**
310    * Builds a while loop.
311    *
312    * @param inputs the loop inputs
313    * @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph
314    * @param bgBuilder WhileSubgraphBuilder to build the body subgraph
315    * @param name name for the loop
316    * @return list of loop outputs, of the same length as {@code inputs}
317    */
whileLoop( Output<?>[] inputs, WhileSubgraphBuilder cgBuilder, WhileSubgraphBuilder bgBuilder, String name)318   public Output<?>[] whileLoop(
319       Output<?>[] inputs,
320       WhileSubgraphBuilder cgBuilder,
321       WhileSubgraphBuilder bgBuilder,
322       String name) {
323     int ninputs = inputs.length;
324     long[] inputHandles = new long[ninputs];
325     int[] inputIndices = new int[ninputs];
326     Output<?>[] outputs = new Output<?>[ninputs];
327 
328     synchronized (nativeHandleLock) {
329       try (Reference ref = ref()) {
330 
331         for (int i = 0; i < ninputs; i++) {
332           inputHandles[i] = inputs[i].op().getUnsafeNativeHandle();
333           inputIndices[i] = inputs[i].index();
334         }
335 
336         long[] outputHandlesAndIndices =
337             whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder);
338 
339         for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
340           Operation op = new Operation(this, outputHandlesAndIndices[i]);
341           outputs[i] = new Output<>(op, (int) outputHandlesAndIndices[j]);
342         }
343       }
344       return outputs;
345     }
346   }
347 
348   private final Object nativeHandleLock = new Object();
349   private long nativeHandle;
350   private int refcount = 0;
351 
352   // Related native objects (such as the TF_Operation object backing an Operation instance)
353   // have a validity tied to that of the Graph. The handles to those native objects are not
354   // valid after Graph.close() has been invoked.
355   //
356   // Instances of the Reference class should be used to ensure the Graph has not been closed
357   // while dependent handles are in use.
358   class Reference implements AutoCloseable {
Reference()359     private Reference() {
360       synchronized (Graph.this.nativeHandleLock) {
361         active = Graph.this.nativeHandle != 0;
362         if (!active) {
363           throw new IllegalStateException("close() has been called on the Graph");
364         }
365         active = true;
366         Graph.this.refcount++;
367       }
368     }
369 
370     @Override
close()371     public void close() {
372       synchronized (Graph.this.nativeHandleLock) {
373         if (!active) {
374           return;
375         }
376         active = false;
377         if (--Graph.this.refcount == 0) {
378           Graph.this.nativeHandleLock.notifyAll();
379         }
380       }
381     }
382 
nativeHandle()383     public long nativeHandle() {
384       synchronized (Graph.this.nativeHandleLock) {
385         return active ? Graph.this.nativeHandle : 0;
386       }
387     }
388 
389     private boolean active;
390   }
391 
ref()392   Reference ref() {
393     return new Reference();
394   }
395 
396   private static final class OperationIterator implements Iterator<Operation> {
397 
OperationIterator(Graph g)398     OperationIterator(Graph g) {
399       this.graph = g;
400       this.operation = null;
401       this.position = 0;
402       this.advance();
403     }
404 
advance()405     private final void advance() {
406       Graph.Reference reference = this.graph.ref();
407 
408       this.operation = null;
409 
410       try {
411         long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position);
412 
413         if ((nativeReturn != null) && (nativeReturn[0] != 0)) {
414           this.operation = new Operation(this.graph, nativeReturn[0]);
415           this.position = (int) nativeReturn[1];
416         }
417       } finally {
418         reference.close();
419       }
420     }
421 
422     @Override
hasNext()423     public boolean hasNext() {
424       return (this.operation != null);
425     }
426 
427     @Override
next()428     public Operation next() {
429       Operation rhett = this.operation;
430       this.advance();
431       return rhett;
432     }
433 
434     @Override
remove()435     public void remove() {
436       throw new UnsupportedOperationException("remove() is unsupported.");
437     }
438 
439     private final Graph graph;
440     private Operation operation;
441     private int position;
442   }
443 
allocate()444   private static native long allocate();
445 
delete(long handle)446   private static native void delete(long handle);
447 
operation(long handle, String name)448   private static native long operation(long handle, String name);
449 
450   // This method returns the Operation native handle at index 0 and the new value for pos at index 1
451   // (see TF_GraphNextOperation)
nextOperation(long handle, int position)452   private static native long[] nextOperation(long handle, int position);
453 
importGraphDef(long handle, byte[] graphDef, String prefix)454   private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
455       throws IllegalArgumentException;
456 
toGraphDef(long handle)457   private static native byte[] toGraphDef(long handle);
458 
addGradients( long handle, String prefix, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices)459   private static native long[] addGradients(
460       long handle,
461       String prefix,
462       long[] inputHandles,
463       int[] inputIndices,
464       long[] outputHandles,
465       int[] outputIndices,
466       long[] gradInputHandles,
467       int[] gradInputIndices);
468 
whileLoop( long handle, long[] inputHandles, int[] inputIndices, String name, WhileSubgraphBuilder condGraphBuilder, WhileSubgraphBuilder bodyGraphBuilder)469   private static native long[] whileLoop(
470       long handle,
471       long[] inputHandles,
472       int[] inputIndices,
473       String name,
474       WhileSubgraphBuilder condGraphBuilder,
475       WhileSubgraphBuilder bodyGraphBuilder);
476 
477   static {
TensorFlow.init()478     TensorFlow.init();
479   }
480 }
481