1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import java.util.Iterator; 19 20 /** 21 * A data flow graph representing a TensorFlow computation. 22 * 23 * <p>Instances of a Graph are thread-safe. 24 * 25 * <p><b>WARNING:</b> Resources consumed by the Graph object must be explicitly freed by invoking 26 * the {@link #close()} method then the Graph object is no longer needed. 27 */ 28 public final class Graph implements AutoCloseable { 29 30 /** Create an empty Graph. */ Graph()31 public Graph() { 32 nativeHandle = allocate(); 33 } 34 35 /** Create a Graph from an existing handle (takes ownership). */ Graph(long nativeHandle)36 Graph(long nativeHandle) { 37 this.nativeHandle = nativeHandle; 38 } 39 40 /** 41 * Release resources associated with the Graph. 42 * 43 * <p>Blocks until there are no active {@link Session} instances referring to this Graph. A Graph 44 * is not usable after close returns. 45 */ 46 @Override close()47 public void close() { 48 synchronized (nativeHandleLock) { 49 if (nativeHandle == 0) { 50 return; 51 } 52 while (refcount > 0) { 53 try { 54 nativeHandleLock.wait(); 55 } catch (InterruptedException e) { 56 Thread.currentThread().interrupt(); 57 // Possible leak of the graph in this case? 58 return; 59 } 60 } 61 delete(nativeHandle); 62 nativeHandle = 0; 63 } 64 } 65 66 /** 67 * Returns the operation (node in the Graph) with the provided name. 68 * 69 * <p>Or {@code null} if no such operation exists in the Graph. 70 */ operation(String name)71 public Operation operation(String name) { 72 synchronized (nativeHandleLock) { 73 long oph = operation(nativeHandle, name); 74 if (oph == 0) { 75 return null; 76 } 77 return new Operation(this, oph); 78 } 79 } 80 81 /** 82 * Iterator over all the {@link Operation}s in the graph. 83 * 84 * <p>The order of iteration is unspecified. Consumers of the iterator will receive no 85 * notification should the underlying graph change during iteration. 86 */ operations()87 public Iterator<Operation> operations() { 88 return new OperationIterator(this); 89 } 90 91 /** 92 * Returns a builder to add {@link Operation}s to the Graph. 93 * 94 * @param type of the Operation (i.e., identifies the computation to be performed) 95 * @param name to refer to the created Operation in the graph. 96 * @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link 97 * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, 98 * then some resources may leak. 99 */ opBuilder(String type, String name)100 public OperationBuilder opBuilder(String type, String name) { 101 return new OperationBuilder(this, type, name); 102 } 103 104 /** 105 * Import a serialized representation of a TensorFlow graph. 106 * 107 * <p>The serialized representation of the graph, often referred to as a <i>GraphDef</i>, can be 108 * generated by {@link #toGraphDef()} and equivalents in other language APIs. 109 * 110 * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph. 111 * @see #importGraphDef(byte[], String) 112 */ importGraphDef(byte[] graphDef)113 public void importGraphDef(byte[] graphDef) throws IllegalArgumentException { 114 importGraphDef(graphDef, ""); 115 } 116 117 /** 118 * Import a serialized representation of a TensorFlow graph. 119 * 120 * @param graphDef the serialized representation of a TensorFlow graph. 121 * @param prefix a prefix that will be prepended to names in graphDef 122 * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph. 123 * @see #importGraphDef(byte[]) 124 */ importGraphDef(byte[] graphDef, String prefix)125 public void importGraphDef(byte[] graphDef, String prefix) throws IllegalArgumentException { 126 if (graphDef == null || prefix == null) { 127 throw new IllegalArgumentException("graphDef and prefix cannot be null"); 128 } 129 synchronized (nativeHandleLock) { 130 importGraphDef(nativeHandle, graphDef, prefix); 131 } 132 } 133 134 /** 135 * Generate a serialized representation of the Graph. 136 * 137 * @see #importGraphDef(byte[]) 138 * @see #importGraphDef(byte[], String) 139 */ toGraphDef()140 public byte[] toGraphDef() { 141 synchronized (nativeHandleLock) { 142 return toGraphDef(nativeHandle); 143 } 144 } 145 146 /** 147 * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., 148 * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} 149 * 150 * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives 151 * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of 152 * {@code y}. 153 * 154 * <p>If {@code dx} is null, the implementation will use dx of {@link 155 * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}. 156 * 157 * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute 158 * gradients. It must be unique within the provided graph or the operation will fail. 159 * 160 * <p>If {@code prefix} is null, then one will be chosen automatically. 161 * 162 * @param prefix unique string prefix applied before the names of nodes added to the graph to 163 * compute gradients. If null, a default one will be chosen. 164 * @param y output of the function to derive 165 * @param x inputs of the function for which partial derivatives are computed 166 * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} 167 * @return the partial derivatives {@code dy} with the size of {@code x} 168 */ addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx)169 public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) { 170 Output<?>[] dy = new Output<?>[x.length]; 171 final long[] yHandles = new long[y.length]; 172 final int[] yIndices = new int[y.length]; 173 final long[] xHandles = new long[x.length]; 174 final int[] xIndices = new int[x.length]; 175 long[] dxHandles = null; 176 int[] dxIndices = null; 177 178 try (Reference ref = ref()) { 179 for (int i = 0; i < y.length; ++i) { 180 yHandles[i] = y[i].op().getUnsafeNativeHandle(); 181 yIndices[i] = y[i].index(); 182 } 183 for (int i = 0; i < x.length; ++i) { 184 xHandles[i] = x[i].op().getUnsafeNativeHandle(); 185 xIndices[i] = x[i].index(); 186 } 187 if (dx != null && dx.length > 0) { 188 dxHandles = new long[dx.length]; 189 dxIndices = new int[dx.length]; 190 191 for (int i = 0; i < dx.length; ++i) { 192 dxHandles[i] = dx[i].op().getUnsafeNativeHandle(); 193 dxIndices[i] = dx[i].index(); 194 } 195 } 196 // Gradient outputs are returned in two continuous arrays concatenated into one. The first 197 // holds the native handles of the gradient operations while the second holds the index of 198 // their output e.g. given 199 // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain 200 // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] 201 long[] dyHandlesAndIndices = 202 addGradients( 203 ref.nativeHandle(), 204 prefix, 205 yHandles, 206 yIndices, 207 xHandles, 208 xIndices, 209 dxHandles, 210 dxIndices); 211 int ndy = dyHandlesAndIndices.length >> 1; 212 if (ndy != dy.length) { 213 throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length 214 + " were expected"); 215 } 216 for (int i = 0, j = ndy; i < ndy; ++i, ++j) { 217 Operation op = new Operation(this, dyHandlesAndIndices[i]); 218 dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); 219 } 220 } 221 return dy; 222 } 223 224 /** 225 * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, 226 * i.e., {@code dy/dx_1, dy/dx_2...} 227 * <p> 228 * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is 229 * a single output, {@code dx} is null and {@code prefix} is null. 230 * 231 * @param y output of the function to derive 232 * @param x inputs of the function for which partial derivatives are computed 233 * @return the partial derivatives {@code dy} with the size of {@code x} 234 */ addGradients(Output<?> y, Output<?>[] x)235 public Output<?>[] addGradients(Output<?> y, Output<?>[] x) { 236 return addGradients(null, new Output<?>[] {y}, x, null); 237 } 238 239 /** 240 * Used to instantiate an abstract class which overrides the buildSubgraph method to build a 241 * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to 242 * create a lambda for the same purpose. 243 * 244 * <p>To be used when calling {@link #whileLoop(Output[], 245 * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)} 246 * 247 * <p>Example usage (prior to Java 8): 248 * 249 * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() { @Override public 250 * void buildSubgraph(Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) { // build 251 * body subgraph } }; } 252 * 253 * <p>Example usage (after Java 8): 254 * 255 * <p>{@code WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> { // 256 * build body subgraph };} 257 */ 258 public interface WhileSubgraphBuilder { 259 /** 260 * To be overridden by user with code to build conditional or body subgraph for a while loop 261 * 262 * @param g the subgraph 263 * @param inputs subgraph inputs 264 * @param outputs subgraph outputs 265 */ buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs)266 public void buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs); 267 } 268 269 // called by while loop code in graph_jni.cc to construct conditional/body subgraphs buildSubgraph( WhileSubgraphBuilder subgraphBuilder, long subgraphHandle, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices)270 private static long[] buildSubgraph( 271 WhileSubgraphBuilder subgraphBuilder, 272 long subgraphHandle, 273 long[] inputHandles, 274 int[] inputIndices, 275 long[] outputHandles, 276 int[] outputIndices) { 277 Graph subgraph = new Graph(subgraphHandle); 278 279 int ninputs = inputHandles.length; 280 int noutputs = outputHandles.length; 281 Output<?>[] inputs = new Output<?>[ninputs]; 282 Output<?>[] outputs = new Output<?>[noutputs]; 283 long[] outputHandlesAndIndices = new long[noutputs * 2]; 284 285 synchronized (subgraph.nativeHandleLock) { 286 try (Reference ref = subgraph.ref()) { 287 288 for (int i = 0; i < ninputs; i++) { 289 Operation op = new Operation(subgraph, inputHandles[i]); 290 inputs[i] = new Output<>(op, inputIndices[i]); 291 } 292 293 for (int i = 0; i < noutputs; i++) { 294 Operation op = new Operation(subgraph, outputHandles[i]); 295 outputs[i] = new Output<>(op, outputIndices[i]); 296 } 297 298 subgraphBuilder.buildSubgraph(subgraph, inputs, outputs); 299 300 for (int i = 0, j = noutputs; i < noutputs; i++, j++) { 301 outputHandlesAndIndices[i] = outputs[i].op().getUnsafeNativeHandle(); 302 outputHandlesAndIndices[j] = (long) outputs[i].index(); 303 } 304 } 305 return outputHandlesAndIndices; 306 } 307 } 308 309 /** 310 * Builds a while loop. 311 * 312 * @param inputs the loop inputs 313 * @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph 314 * @param bgBuilder WhileSubgraphBuilder to build the body subgraph 315 * @param name name for the loop 316 * @return list of loop outputs, of the same length as {@code inputs} 317 */ whileLoop( Output<?>[] inputs, WhileSubgraphBuilder cgBuilder, WhileSubgraphBuilder bgBuilder, String name)318 public Output<?>[] whileLoop( 319 Output<?>[] inputs, 320 WhileSubgraphBuilder cgBuilder, 321 WhileSubgraphBuilder bgBuilder, 322 String name) { 323 int ninputs = inputs.length; 324 long[] inputHandles = new long[ninputs]; 325 int[] inputIndices = new int[ninputs]; 326 Output<?>[] outputs = new Output<?>[ninputs]; 327 328 synchronized (nativeHandleLock) { 329 try (Reference ref = ref()) { 330 331 for (int i = 0; i < ninputs; i++) { 332 inputHandles[i] = inputs[i].op().getUnsafeNativeHandle(); 333 inputIndices[i] = inputs[i].index(); 334 } 335 336 long[] outputHandlesAndIndices = 337 whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder); 338 339 for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { 340 Operation op = new Operation(this, outputHandlesAndIndices[i]); 341 outputs[i] = new Output<>(op, (int) outputHandlesAndIndices[j]); 342 } 343 } 344 return outputs; 345 } 346 } 347 348 private final Object nativeHandleLock = new Object(); 349 private long nativeHandle; 350 private int refcount = 0; 351 352 // Related native objects (such as the TF_Operation object backing an Operation instance) 353 // have a validity tied to that of the Graph. The handles to those native objects are not 354 // valid after Graph.close() has been invoked. 355 // 356 // Instances of the Reference class should be used to ensure the Graph has not been closed 357 // while dependent handles are in use. 358 class Reference implements AutoCloseable { Reference()359 private Reference() { 360 synchronized (Graph.this.nativeHandleLock) { 361 active = Graph.this.nativeHandle != 0; 362 if (!active) { 363 throw new IllegalStateException("close() has been called on the Graph"); 364 } 365 active = true; 366 Graph.this.refcount++; 367 } 368 } 369 370 @Override close()371 public void close() { 372 synchronized (Graph.this.nativeHandleLock) { 373 if (!active) { 374 return; 375 } 376 active = false; 377 if (--Graph.this.refcount == 0) { 378 Graph.this.nativeHandleLock.notifyAll(); 379 } 380 } 381 } 382 nativeHandle()383 public long nativeHandle() { 384 synchronized (Graph.this.nativeHandleLock) { 385 return active ? Graph.this.nativeHandle : 0; 386 } 387 } 388 389 private boolean active; 390 } 391 ref()392 Reference ref() { 393 return new Reference(); 394 } 395 396 private static final class OperationIterator implements Iterator<Operation> { 397 OperationIterator(Graph g)398 OperationIterator(Graph g) { 399 this.graph = g; 400 this.operation = null; 401 this.position = 0; 402 this.advance(); 403 } 404 advance()405 private final void advance() { 406 Graph.Reference reference = this.graph.ref(); 407 408 this.operation = null; 409 410 try { 411 long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position); 412 413 if ((nativeReturn != null) && (nativeReturn[0] != 0)) { 414 this.operation = new Operation(this.graph, nativeReturn[0]); 415 this.position = (int) nativeReturn[1]; 416 } 417 } finally { 418 reference.close(); 419 } 420 } 421 422 @Override hasNext()423 public boolean hasNext() { 424 return (this.operation != null); 425 } 426 427 @Override next()428 public Operation next() { 429 Operation rhett = this.operation; 430 this.advance(); 431 return rhett; 432 } 433 434 @Override remove()435 public void remove() { 436 throw new UnsupportedOperationException("remove() is unsupported."); 437 } 438 439 private final Graph graph; 440 private Operation operation; 441 private int position; 442 } 443 allocate()444 private static native long allocate(); 445 delete(long handle)446 private static native void delete(long handle); 447 operation(long handle, String name)448 private static native long operation(long handle, String name); 449 450 // This method returns the Operation native handle at index 0 and the new value for pos at index 1 451 // (see TF_GraphNextOperation) nextOperation(long handle, int position)452 private static native long[] nextOperation(long handle, int position); 453 importGraphDef(long handle, byte[] graphDef, String prefix)454 private static native void importGraphDef(long handle, byte[] graphDef, String prefix) 455 throws IllegalArgumentException; 456 toGraphDef(long handle)457 private static native byte[] toGraphDef(long handle); 458 addGradients( long handle, String prefix, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices)459 private static native long[] addGradients( 460 long handle, 461 String prefix, 462 long[] inputHandles, 463 int[] inputIndices, 464 long[] outputHandles, 465 int[] outputIndices, 466 long[] gradInputHandles, 467 int[] gradInputIndices); 468 whileLoop( long handle, long[] inputHandles, int[] inputIndices, String name, WhileSubgraphBuilder condGraphBuilder, WhileSubgraphBuilder bodyGraphBuilder)469 private static native long[] whileLoop( 470 long handle, 471 long[] inputHandles, 472 int[] inputIndices, 473 String name, 474 WhileSubgraphBuilder condGraphBuilder, 475 WhileSubgraphBuilder bodyGraphBuilder); 476 477 static { TensorFlow.init()478 TensorFlow.init(); 479 } 480 } 481