/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * A data flow graph representing a TensorFlow computation.
 *
 * <p>Instances of a Graph are thread-safe.
 *
 * <p><b>WARNING:</b> Resources consumed by the Graph object must be explicitly freed by invoking
 * the {@link #close()} method when the Graph object is no longer needed.
 */
public final class Graph implements ExecutionEnvironment, AutoCloseable {

  /** Create an empty Graph. */
  public Graph() {
    nativeHandle = allocate();
  }

  /** Create a Graph from an existing handle (takes ownership). */
  Graph(long nativeHandle) {
    this.nativeHandle = nativeHandle;
  }

  /**
   * Release resources associated with the Graph.
   *
   * <p>Blocks until there are no active {@link Session} instances referring to this Graph. A Graph
   * is not usable after close returns.
   *
   * <p>Calling this method on an already closed Graph has no effect. If the calling thread is
   * interrupted while waiting for outstanding references to be released, the interrupt status is
   * restored and the method returns <em>without</em> releasing the native graph.
   */
  @Override
  public void close() {
    synchronized (nativeHandleLock) {
      if (nativeHandle == 0) {
        // Already closed.
        return;
      }
      // Wait for every outstanding Reference to be closed; Reference.close() notifies this lock
      // when the count drops to zero.
      while (refcount > 0) {
        try {
          nativeHandleLock.wait();
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          // Possible leak of the graph in this case?
          return;
        }
      }
      delete(nativeHandle);
      nativeHandle = 0;
    }
  }

  /**
   * Returns the operation (node in the Graph) with the provided name.
   *
   * <p>Or {@code null} if no such operation exists in the Graph.
   *
   * @param name name of the operation to look up
   * @return the matching operation, or {@code null} if the native lookup finds none
   */
  public GraphOperation operation(String name) {
    synchronized (nativeHandleLock) {
      long oph = operation(nativeHandle, name);
      if (oph == 0) {
        return null;
      }
      return new GraphOperation(this, oph);
    }
  }

  /**
   * Iterator over all the {@link Operation}s in the graph.
   *
   * <p>The order of iteration is unspecified. Consumers of the iterator will receive no
   * notification should the underlying graph change during iteration.
   *
   * <p>The returned iterator does not support {@link Iterator#remove()}, and its {@link
   * Iterator#next()} throws {@link NoSuchElementException} once all operations have been returned.
   *
   * @throws IllegalStateException if the Graph has already been closed
   */
  public Iterator<Operation> operations() {
    return new OperationIterator(this);
  }

  /**
   * Returns a builder to add {@link Operation}s to the Graph.
   *
   * @param type of the Operation (i.e., identifies the computation to be performed)
   * @param name to refer to the created Operation in the graph.
   * @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link
   *     OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked,
   *     then some resources may leak.
   */
  @Override
  public GraphOperationBuilder opBuilder(String type, String name) {
    return new GraphOperationBuilder(this, type, name);
  }

  /**
   * Import a serialized representation of a TensorFlow graph.
   *
   * <p>The serialized representation of the graph, often referred to as a <i>GraphDef</i>, can be
   * generated by {@link #toGraphDef()} and equivalents in other language APIs.
   *
   * <p>Equivalent to {@link #importGraphDef(byte[], String)} with an empty prefix.
   *
   * @param graphDef the serialized representation of a TensorFlow graph.
   * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
   * @see #importGraphDef(byte[], String)
   */
  public void importGraphDef(byte[] graphDef) throws IllegalArgumentException {
    importGraphDef(graphDef, "");
  }

  /**
   * Import a serialized representation of a TensorFlow graph.
   *
   * @param graphDef the serialized representation of a TensorFlow graph.
   * @param prefix a prefix that will be prepended to names in graphDef
   * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph, or
   *     if either argument is {@code null}.
   * @see #importGraphDef(byte[])
   */
  public void importGraphDef(byte[] graphDef, String prefix) throws IllegalArgumentException {
    if (graphDef == null || prefix == null) {
      throw new IllegalArgumentException("graphDef and prefix cannot be null");
    }
    synchronized (nativeHandleLock) {
      importGraphDef(nativeHandle, graphDef, prefix);
    }
  }

  /**
   * Generate a serialized representation of the Graph.
   *
   * @return the serialized graph as produced by the native implementation
   * @see #importGraphDef(byte[])
   * @see #importGraphDef(byte[], String)
   */
  public byte[] toGraphDef() {
    synchronized (nativeHandleLock) {
      return toGraphDef(nativeHandle);
    }
  }

  /**
   * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
   * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
   *
   * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
   * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
   * {@code y}.
   *
   * <p>If {@code dx} is null, the implementation will use dx of {@link
   * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}. An empty {@code dx}
   * array is treated the same way as {@code null}.
   *
   * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
   * gradients. It must be unique within the provided graph or the operation will fail.
   *
   * <p>If {@code prefix} is null, then one will be chosen automatically.
   *
   * @param prefix unique string prefix applied before the names of nodes added to the graph to
   *     compute gradients. If null, a default one will be chosen.
   * @param y output of the function to derive
   * @param x inputs of the function for which partial derivatives are computed
   * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
   * @return the partial derivatives {@code dy} with the size of {@code x}
   * @throws IllegalStateException if the Graph has been closed, or if the number of gradients
   *     added to the graph differs from the size of {@code x}
   */
  public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
    Output<?>[] dy = new Output<?>[x.length];
    final long[] yHandles = new long[y.length];
    final int[] yIndices = new int[y.length];
    final long[] xHandles = new long[x.length];
    final int[] xIndices = new int[x.length];
    long[] dxHandles = null;
    int[] dxIndices = null;

    // The Reference keeps the Graph from being closed while native handles are in use.
    try (Reference ref = ref()) {
      for (int i = 0; i < y.length; ++i) {
        yHandles[i] = y[i].getUnsafeNativeHandle();
        yIndices[i] = y[i].index();
      }
      for (int i = 0; i < x.length; ++i) {
        xHandles[i] = x[i].getUnsafeNativeHandle();
        xIndices[i] = x[i].index();
      }
      if (dx != null && dx.length > 0) {
        dxHandles = new long[dx.length];
        dxIndices = new int[dx.length];

        for (int i = 0; i < dx.length; ++i) {
          dxHandles[i] = dx[i].getUnsafeNativeHandle();
          dxIndices[i] = dx[i].index();
        }
      }
      // Gradient outputs are returned in two continuous arrays concatenated into one. The first
      // holds the native handles of the gradient operations while the second holds the index of
      // their output e.g. given
      // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
      // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
      long[] dyHandlesAndIndices =
          addGradients(
              ref.nativeHandle(),
              prefix,
              yHandles,
              yIndices,
              xHandles,
              xIndices,
              dxHandles,
              dxIndices);
      int ndy = dyHandlesAndIndices.length >> 1;
      if (ndy != dy.length) {
        throw new IllegalStateException(
            ndy + " gradients were added to the graph when " + dy.length + " were expected");
      }
      for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
        GraphOperation op = new GraphOperation(this, dyHandlesAndIndices[i]);
        dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
      }
    }
    return dy;
  }

  /**
   * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
   * {@code dy/dx_1, dy/dx_2...}
   *
   * <p>This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])}
   * where {@code y} is a single output, {@code dx} is null and {@code prefix} is null.
   *
   * @param y output of the function to derive
   * @param x inputs of the function for which partial derivatives are computed
   * @return the partial derivatives {@code dy} with the size of {@code x}
   */
  public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
    return addGradients(null, new Output<?>[] {y}, x, null);
  }

  /**
   * Used to instantiate an abstract class which overrides the buildSubgraph method to build a
   * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
   * create a lambda for the same purpose.
   *
   * <p>To be used when calling {@link #whileLoop(Output[],
   * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)}
   *
   * <p>Example usage (prior to Java 8):
   *
   * <pre>{@code
   * WhileSubgraphBuilder bodyGraphBuilder = new WhileSubgraphBuilder() {
   *   @Override
   *   public void buildSubgraph(Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
   *     // build body subgraph
   *   }
   * };
   * }</pre>
   *
   * <p>Example usage (after Java 8):
   *
   * <pre>{@code
   * WhileSubgraphBuilder bodyGraphBuilder = (bodyGraph, bodyInputs, bodyOutputs) -> {
   *   // build body subgraph
   * };
   * }</pre>
   */
  public interface WhileSubgraphBuilder {
    /**
     * To be overridden by user with code to build conditional or body subgraph for a while loop
     *
     * @param g the subgraph
     * @param inputs subgraph inputs
     * @param outputs subgraph outputs
     */
    void buildSubgraph(Graph g, Output<?>[] inputs, Output<?>[] outputs);
  }

  // called by while loop code in graph_jni.cc to construct conditional/body subgraphs
  //
  // Wraps the native subgraph and its input/output handles in Java objects, hands them to the
  // user-supplied builder, and then returns the (possibly replaced) outputs as one array laid out
  // as [handle_0, ..., handle_{n-1}, index_0, ..., index_{n-1}].
  private static long[] buildSubgraph(
      WhileSubgraphBuilder subgraphBuilder,
      long subgraphHandle,
      long[] inputHandles,
      int[] inputIndices,
      long[] outputHandles,
      int[] outputIndices) {
    Graph subgraph = new Graph(subgraphHandle);

    int ninputs = inputHandles.length;
    int noutputs = outputHandles.length;
    Output<?>[] inputs = new Output<?>[ninputs];
    Output<?>[] outputs = new Output<?>[noutputs];
    long[] outputHandlesAndIndices = new long[noutputs * 2];

    synchronized (subgraph.nativeHandleLock) {
      // The Reference is held only for its side effect of keeping the subgraph open.
      try (Reference ref = subgraph.ref()) {

        for (int i = 0; i < ninputs; i++) {
          Operation op = new GraphOperation(subgraph, inputHandles[i]);
          inputs[i] = op.output(inputIndices[i]);
        }

        for (int i = 0; i < noutputs; i++) {
          Operation op = new GraphOperation(subgraph, outputHandles[i]);
          outputs[i] = op.output(outputIndices[i]);
        }

        subgraphBuilder.buildSubgraph(subgraph, inputs, outputs);

        // Read the outputs back after the builder ran: it may have stored new entries.
        for (int i = 0, j = noutputs; i < noutputs; i++, j++) {
          outputHandlesAndIndices[i] = outputs[i].getUnsafeNativeHandle();
          outputHandlesAndIndices[j] = outputs[i].index();
        }
      }
      return outputHandlesAndIndices;
    }
  }

  /**
   * Builds a while loop.
   *
   * @param inputs the loop inputs
   * @param cgBuilder WhileSubgraphBuilder to build the conditional subgraph
   * @param bgBuilder WhileSubgraphBuilder to build the body subgraph
   * @param name name for the loop
   * @return list of loop outputs, of the same length as {@code inputs}
   * @throws IllegalStateException if the Graph has been closed
   */
  public Output<?>[] whileLoop(
      Output<?>[] inputs,
      WhileSubgraphBuilder cgBuilder,
      WhileSubgraphBuilder bgBuilder,
      String name) {
    int ninputs = inputs.length;
    long[] inputHandles = new long[ninputs];
    int[] inputIndices = new int[ninputs];
    Output<?>[] outputs = new Output<?>[ninputs];

    synchronized (nativeHandleLock) {
      try (Reference ref = ref()) {

        for (int i = 0; i < ninputs; i++) {
          inputHandles[i] = inputs[i].getUnsafeNativeHandle();
          inputIndices[i] = inputs[i].index();
        }

        // The native result is laid out as [handle_0, ..., handle_{n-1}, index_0, ...,
        // index_{n-1}].
        long[] outputHandlesAndIndices =
            whileLoop(
                ref.nativeHandle(), inputHandles, inputIndices, name, cgBuilder, bgBuilder);

        for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
          Operation op = new GraphOperation(this, outputHandlesAndIndices[i]);
          outputs[i] = op.output((int) outputHandlesAndIndices[j]);
        }
      }
      return outputs;
    }
  }

  // Guards nativeHandle and refcount; close() waits on it and Reference.close() notifies it.
  private final Object nativeHandleLock = new Object();
  // Handle of the native graph; 0 once close() has released it.
  private long nativeHandle;
  // Number of Reference instances that are currently active.
  private int refcount = 0;

  /**
   * Keeps the Graph open while dependent native handles are in use.
   *
   * <p>Related native objects (such as the TF_Operation object backing an Operation instance) have
   * a validity tied to that of the Graph. The handles to those native objects are not valid after
   * Graph.close() has been invoked.
   *
   * <p>Instances of the Reference class should be used to ensure the Graph has not been closed
   * while dependent handles are in use.
   */
  class Reference implements AutoCloseable {
    /**
     * Registers a new active reference on the enclosing Graph.
     *
     * @throws IllegalStateException if the Graph has already been closed
     */
    private Reference() {
      synchronized (Graph.this.nativeHandleLock) {
        if (Graph.this.nativeHandle == 0) {
          throw new IllegalStateException("close() has been called on the Graph");
        }
        active = true;
        Graph.this.refcount++;
      }
    }

    /**
     * Releases this reference. Closing an already closed reference has no effect. Threads waiting
     * in {@link Graph#close()} are woken up when the last active reference is released.
     */
    @Override
    public void close() {
      synchronized (Graph.this.nativeHandleLock) {
        if (!active) {
          return;
        }
        active = false;
        if (--Graph.this.refcount == 0) {
          Graph.this.nativeHandleLock.notifyAll();
        }
      }
    }

    /** Returns the native handle of the Graph, or 0 if this reference has been closed. */
    public long nativeHandle() {
      synchronized (Graph.this.nativeHandleLock) {
        return active ? Graph.this.nativeHandle : 0;
      }
    }

    private boolean active;
  }

  /**
   * Creates a {@link Reference} that keeps this Graph open until the reference is closed.
   *
   * @throws IllegalStateException if the Graph has already been closed
   */
  Reference ref() {
    return new Reference();
  }

  /**
   * Iterator over the operations of a Graph. The next operation is fetched eagerly, so {@link
   * #hasNext()} is a simple null check.
   */
  private static final class OperationIterator implements Iterator<Operation> {

    OperationIterator(Graph g) {
      this.graph = g;
      this.operation = null;
      this.position = 0;
      this.advance();
    }

    /**
     * Fetches the operation at the current position (if any) and moves the position forward.
     * Leaves {@code operation} null when the native call reports no further operation.
     */
    private void advance() {
      try (Graph.Reference reference = this.graph.ref()) {
        // Cleared only after the reference was acquired so that a closed Graph leaves the
        // iterator state untouched.
        this.operation = null;

        long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position);

        if ((nativeReturn != null) && (nativeReturn[0] != 0)) {
          this.operation = new GraphOperation(this.graph, nativeReturn[0]);
          this.position = (int) nativeReturn[1];
        }
      }
    }

    @Override
    public boolean hasNext() {
      return (this.operation != null);
    }

    /**
     * Returns the prefetched operation and fetches the following one.
     *
     * @throws NoSuchElementException if the iteration has no more operations
     * @throws IllegalStateException if the Graph has been closed
     */
    @Override
    public Operation next() {
      if (this.operation == null) {
        throw new NoSuchElementException("No more operations in the Graph");
      }
      Operation current = this.operation;
      this.advance();
      return current;
    }

    @Override
    public void remove() {
      throw new UnsupportedOperationException("remove() is unsupported.");
    }

    private final Graph graph;
    private Operation operation;
    private int position;
  }

  // Native (JNI) entry points. Signatures must stay in sync with the native implementation.

  private static native long allocate();

  private static native void delete(long handle);

  private static native long operation(long handle, String name);

  // This method returns the Operation native handle at index 0 and the new value for pos at index 1
  // (see TF_GraphNextOperation)
  private static native long[] nextOperation(long handle, int position);

  private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
      throws IllegalArgumentException;

  private static native byte[] toGraphDef(long handle);

  private static native long[] addGradients(
      long handle,
      String prefix,
      long[] inputHandles,
      int[] inputIndices,
      long[] outputHandles,
      int[] outputIndices,
      long[] gradInputHandles,
      int[] gradInputIndices);

  private static native long[] whileLoop(
      long handle,
      long[] inputHandles,
      int[] inputIndices,
      String name,
      WhileSubgraphBuilder condGraphBuilder,
      WhileSubgraphBuilder bodyGraphBuilder);

  static {
    TensorFlow.init();
  }
}