1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import java.util.ArrayList; 19 import java.util.List; 20 21 /** 22 * Driver for {@link Graph} execution. 23 * 24 * <p>A {@code Session} instance encapsulates the environment in which {@link Operation}s in a 25 * {@link Graph} are executed to compute {@link Tensor Tensors}. For example: 26 * 27 * <pre>{@code 28 * // Let's say graph is an instance of the Graph class 29 * // for the computation y = 3 * x 30 * 31 * try (Session s = new Session(graph)) { 32 * try (Tensor x = Tensor.create(2.0f); 33 * Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) { 34 * System.out.println(y.floatValue()); // Will print 6.0f 35 * } 36 * try (Tensor x = Tensor.create(1.1f); 37 * Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) { 38 * System.out.println(y.floatValue()); // Will print 3.3f 39 * } 40 * } 41 * }</pre> 42 * 43 * <p><b>WARNING:</b>A {@code Session} owns resources that <b>must</b> be explicitly freed by 44 * invoking {@link #close()}. 45 * 46 * <p>Instances of a Session are thread-safe. 47 */ 48 public final class Session implements AutoCloseable { 49 50 /** Construct a new session with the associated {@link Graph}. */ Session(Graph g)51 public Session(Graph g) { 52 this(g, null); 53 } 54 55 /** 56 * Construct a new session with the associated {@link Graph} and configuration options. 57 * 58 * @param g The {@link Graph} the created Session will operate on. 59 * @param config Configuration parameters for the session specified as a serialized <a 60 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto</a> 61 * protocol buffer. 62 * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto 63 * protocol buffer. 64 */ Session(Graph g, byte[] config)65 public Session(Graph g, byte[] config) { 66 graph = g; 67 Graph.Reference r = g.ref(); 68 try { 69 nativeHandle = 70 (config == null) ? allocate(r.nativeHandle()) : allocate2(r.nativeHandle(), null, config); 71 graphRef = g.ref(); 72 } finally { 73 r.close(); 74 } 75 } 76 77 /** Wrap an existing session with the associated {@link Graph}. */ Session(Graph g, long nativeHandle)78 Session(Graph g, long nativeHandle) { 79 graph = g; 80 this.nativeHandle = nativeHandle; 81 graphRef = g.ref(); 82 } 83 84 /** 85 * Release resources associated with the Session. 86 * 87 * <p>Blocks until there are no active executions ({@link Session.Runner#run()} calls). A Session 88 * is not usable after close returns. 89 */ 90 @Override close()91 public void close() { 92 graphRef.close(); 93 synchronized (nativeHandleLock) { 94 if (nativeHandle == 0) { 95 return; 96 } 97 while (numActiveRuns > 0) { 98 try { 99 nativeHandleLock.wait(); 100 } catch (InterruptedException e) { 101 Thread.currentThread().interrupt(); 102 // Possible leak of the Session and Graph in this case? 103 return; 104 } 105 } 106 delete(nativeHandle); 107 nativeHandle = 0; 108 } 109 } 110 111 /** 112 * Run {@link Operation}s and evaluate {@link Tensor Tensors}. 113 * 114 * <p>A Runner runs the necessary graph fragments to execute every {@link Operation} required to 115 * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String,int,Tensor)} call allows 116 * callers to override the value of {@link Tensor Tensors} in the graph by substituting the 117 * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link 118 * #feed(String,int,Tensor)}. 119 */ 120 public final class Runner { 121 /** 122 * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces. 123 * 124 * @param operation Is either the string name of the operation, in which case this method is a 125 * shorthand for {@code feed(operation, 0)}, or it is a string of the form 126 * <tt>operation_name:output_index</tt> , in which case this method acts like {@code 127 * feed(operation_name, output_index)}. These colon-separated names are commonly used in the 128 * {@code SignatureDef} protocol buffer messages that are included in {@link 129 * SavedModelBundle#metaGraphDef()}. 130 */ feed(String operation, Tensor<?> t)131 public Runner feed(String operation, Tensor<?> t) { 132 return feed(parseOutput(operation), t); 133 } 134 135 /** 136 * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} 137 * for the value it produces. 138 * 139 * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which 140 * one {@code t} is being provided for. 141 */ feed(String operation, int index, Tensor<?> t)142 public Runner feed(String operation, int index, Tensor<?> t) { 143 Operation op = operationByName(operation); 144 if (op != null) { 145 inputs.add(op.output(index)); 146 inputTensors.add(t); 147 } 148 return this; 149 } 150 151 /** 152 * Use {@code t} instead of the Tensor referred to by executing the operation referred to by 153 * {@code operand}. 154 */ feed(Operand<?> operand, Tensor<?> t)155 public Runner feed(Operand<?> operand, Tensor<?> t) { 156 inputs.add(operand.asOutput()); 157 inputTensors.add(t); 158 return this; 159 } 160 161 /** 162 * Make {@link #run()} return the output of {@code operation}. 163 * 164 * @param operation Is either the string name of the operation, in which case this method is a 165 * shorthand for {@code fetch(operation, 0)}, or it is a string of the form 166 * <tt>operation_name:output_index</tt> , in which case this method acts like {@code 167 * fetch(operation_name, output_index)}. These colon-separated names are commonly used in 168 * the {@code SignatureDef} protocol buffer messages that are included in {@link 169 * SavedModelBundle#metaGraphDef()}. 170 */ fetch(String operation)171 public Runner fetch(String operation) { 172 return fetch(parseOutput(operation)); 173 } 174 175 /** 176 * Make {@link #run()} return the {@code index}-th output of {@code operation}. 177 * 178 * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which 179 * one to return. 180 */ fetch(String operation, int index)181 public Runner fetch(String operation, int index) { 182 Operation op = operationByName(operation); 183 if (op != null) { 184 outputs.add(op.output(index)); 185 } 186 return this; 187 } 188 189 /** 190 * Makes {@link #run()} return the Tensor referred to by {@code output}. 191 */ fetch(Output<?> output)192 public Runner fetch(Output<?> output) { 193 outputs.add(output); 194 return this; 195 } 196 197 /** 198 * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. 199 */ fetch(Operand<?> operand)200 public Runner fetch(Operand<?> operand) { 201 return fetch(operand.asOutput()); 202 } 203 204 /** 205 * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor 206 * Tensors}. 207 */ addTarget(String operation)208 public Runner addTarget(String operation) { 209 GraphOperation op = operationByName(operation); 210 if (op != null) { 211 targets.add(op); 212 } 213 return this; 214 } 215 216 /** 217 * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor 218 * Tensors}. 219 * 220 * @throws IllegalArgumentException if the operation is not a {@link GraphOperation} 221 */ addTarget(Operation operation)222 public Runner addTarget(Operation operation) { 223 if (!(operation instanceof GraphOperation)) { 224 throw new IllegalArgumentException( 225 "Operation of type " 226 + operation.getClass().getName() 227 + " is not supported in graph sessions"); 228 } 229 targets.add((GraphOperation) operation); 230 return this; 231 } 232 233 /** 234 * Make {@link #run} execute {@code operand}, but not return any evaluated {@link Tensor 235 * Tensors}. 236 */ addTarget(Operand<?> operand)237 public Runner addTarget(Operand<?> operand) { 238 return addTarget(operand.asOutput().op()); 239 } 240 241 /** 242 * (Experimental method): set options (typically for debugging) for this run. 243 * 244 * <p>The options are presented as a serialized <a 245 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions 246 * protocol buffer</a>. 247 * 248 * <p>The org.tensorflow package is free of any protocol buffer dependencies in order to remain 249 * friendly to resource constrained systems (where something like <a 250 * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may 251 * be more appropriate). A cost of that is this lack of type-safety in this API function. This 252 * choice is under review and this function may be replaced by more type-safe equivalents at any 253 * time. 254 */ setOptions(byte[] options)255 public Runner setOptions(byte[] options) { 256 this.runOptions = options; 257 return this; 258 } 259 260 /** 261 * Execute the graph fragments necessary to compute all requested fetches. 262 * 263 * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., 264 * the caller must call {@link Tensor#close} on all elements of the returned list to free up 265 * resources. 266 * 267 * <p>TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it 268 * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in 269 * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a 270 * {@code Map<Output, Tensor>}? 271 * 272 * <p>TODO(andrewmyers): It would also be good if whatever is returned here made it easier to 273 * extract output tensors in a type-safe way. 274 */ run()275 public List<Tensor<?>> run() { 276 return runHelper(false).outputs; 277 } 278 279 /** 280 * Execute graph fragments to compute requested fetches and return metadata about the run. 281 * 282 * <p>This is exactly like {@link #run()}, but in addition to the requested Tensors, also 283 * returns metadata about the graph execution in the form of a serialized <a 284 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata 285 * protocol buffer</a>. 286 */ runAndFetchMetadata()287 public Run runAndFetchMetadata() { 288 return runHelper(true); 289 } 290 runHelper(boolean wantMetadata)291 private Run runHelper(boolean wantMetadata) { 292 long[] inputTensorHandles = new long[inputTensors.size()]; 293 long[] inputOpHandles = new long[inputs.size()]; 294 int[] inputOpIndices = new int[inputs.size()]; 295 long[] outputOpHandles = new long[outputs.size()]; 296 int[] outputOpIndices = new int[outputs.size()]; 297 long[] targetOpHandles = new long[targets.size()]; 298 long[] outputTensorHandles = new long[outputs.size()]; 299 300 // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the 301 // validity of the Graph and graphRef ensures that. 302 int idx = 0; 303 for (Tensor<?> t : inputTensors) { 304 inputTensorHandles[idx++] = t.getNativeHandle(); 305 } 306 idx = 0; 307 for (Output<?> o : inputs) { 308 inputOpHandles[idx] = o.getUnsafeNativeHandle(); 309 inputOpIndices[idx] = o.index(); 310 idx++; 311 } 312 idx = 0; 313 for (Output<?> o : outputs) { 314 outputOpHandles[idx] = o.getUnsafeNativeHandle(); 315 outputOpIndices[idx] = o.index(); 316 idx++; 317 } 318 idx = 0; 319 for (GraphOperation op : targets) { 320 targetOpHandles[idx++] = op.getUnsafeNativeHandle(); 321 } 322 Reference runRef = new Reference(); 323 byte[] metadata = null; 324 try { 325 metadata = 326 Session.run( 327 nativeHandle, 328 runOptions, 329 inputTensorHandles, 330 inputOpHandles, 331 inputOpIndices, 332 outputOpHandles, 333 outputOpIndices, 334 targetOpHandles, 335 wantMetadata, 336 outputTensorHandles); 337 } finally { 338 runRef.close(); 339 } 340 List<Tensor<?>> outputs = new ArrayList<Tensor<?>>(); 341 for (long h : outputTensorHandles) { 342 try { 343 outputs.add(Tensor.fromHandle(h)); 344 } catch (Exception e) { 345 for (Tensor<?> t : outputs) { 346 t.close(); 347 } 348 outputs.clear(); 349 throw e; 350 } 351 } 352 Run ret = new Run(); 353 ret.outputs = outputs; 354 ret.metadata = metadata; 355 return ret; 356 } 357 358 private class Reference implements AutoCloseable { Reference()359 public Reference() { 360 synchronized (nativeHandleLock) { 361 if (nativeHandle == 0) { 362 throw new IllegalStateException("run() cannot be called on the Session after close()"); 363 } 364 ++numActiveRuns; 365 } 366 } 367 368 @Override close()369 public void close() { 370 synchronized (nativeHandleLock) { 371 if (nativeHandle == 0) { 372 return; 373 } 374 if (--numActiveRuns == 0) { 375 nativeHandleLock.notifyAll(); 376 } 377 } 378 } 379 } 380 operationByName(String opName)381 private GraphOperation operationByName(String opName) { 382 GraphOperation op = graph.operation(opName); 383 if (op == null) { 384 throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph"); 385 } 386 return op; 387 } 388 389 @SuppressWarnings("rawtypes") parseOutput(String opName)390 private Output<?> parseOutput(String opName) { 391 int colon = opName.lastIndexOf(':'); 392 if (colon == -1 || colon == opName.length() - 1) { 393 return new Output(operationByName(opName), 0); 394 } 395 try { 396 String op = opName.substring(0, colon); 397 int index = Integer.parseInt(opName.substring(colon + 1)); 398 return new Output(operationByName(op), index); 399 } catch (NumberFormatException e) { 400 return new Output(operationByName(opName), 0); 401 } 402 } 403 404 private ArrayList<Output<?>> inputs = new ArrayList<Output<?>>(); 405 private ArrayList<Tensor<?>> inputTensors = new ArrayList<Tensor<?>>(); 406 private ArrayList<Output<?>> outputs = new ArrayList<Output<?>>(); 407 private ArrayList<GraphOperation> targets = new ArrayList<GraphOperation>(); 408 private byte[] runOptions = null; 409 } 410 411 /** Create a Runner to execute graph operations and evaluate Tensors. */ runner()412 public Runner runner() { 413 return new Runner(); 414 } 415 416 /** 417 * Output tensors and metadata obtained when executing a session. 418 * 419 * <p>See {@link Runner#runAndFetchMetadata()} 420 */ 421 public static final class Run { 422 /** Tensors from requested fetches. */ 423 public List<Tensor<?>> outputs; 424 425 /** 426 * (Experimental): Metadata about the run. 427 * 428 * <p>A serialized <a 429 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata 430 * protocol buffer</a>. The org.tensorflow package is free of any protocol buffer dependencies 431 * in order to remain friendly to resource constrained systems (where something like <a 432 * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may 433 * be more appropriate). A cost of that is this opaque blob. This choice is under review and 434 * this field may be replaced by more type-safe equivalents at any time. 435 */ 436 public byte[] metadata; 437 } 438 439 private final Graph graph; 440 private final Graph.Reference graphRef; 441 442 private final Object nativeHandleLock = new Object(); 443 private long nativeHandle; 444 private int numActiveRuns; 445 446 // TODO(ashankar): Remove after TensorFlow 1.2 has been released with allocate2(). allocate(long graphHandle)447 private static native long allocate(long graphHandle); 448 allocate2(long graphHandle, String target, byte[] config)449 private static native long allocate2(long graphHandle, String target, byte[] config); 450 delete(long handle)451 private static native void delete(long handle); 452 453 /** 454 * Execute a session. 455 * 456 * <p>The author apologizes for the ugliness of the long argument list of this method. However, 457 * take solace in the fact that this is a private method meant to cross the JNI boundary. 458 * 459 * @param handle to the C API TF_Session object (Session.nativeHandle) 460 * @param runOptions serialized representation of a RunOptions protocol buffer, or null 461 * @param inputOpHandles (see inputOpIndices) 462 * @param inputOpIndices (see inputTensorHandles) 463 * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values 464 * that are being "fed" (do not need to be computed) during graph execution. 465 * inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the 466 * inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that 467 * inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. 468 * @param outputOpHandles (see outputOpIndices) 469 * @param outputOpIndices together with outputOpHandles identifies the set of values that should 470 * be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is 471 * required that outputOpHandles.length == outputOpIndices.length. 472 * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose 473 * output will not be returned 474 * @param wantRunMetadata indicates whether metadata about this execution should be returned. 475 * @param outputTensorHandles will be filled in with handles to the outputs requested. It is 476 * required that outputTensorHandles.length == outputOpHandles.length. 477 * @return if wantRunMetadata is true, serialized representation of the RunMetadata protocol 478 * buffer, false otherwise. 479 */ run( long handle, byte[] runOptions, long[] inputTensorHandles, long[] inputOpHandles, int[] inputOpIndices, long[] outputOpHandles, int[] outputOpIndices, long[] targetOpHandles, boolean wantRunMetadata, long[] outputTensorHandles)480 private static native byte[] run( 481 long handle, 482 byte[] runOptions, 483 long[] inputTensorHandles, 484 long[] inputOpHandles, 485 int[] inputOpIndices, 486 long[] outputOpHandles, 487 int[] outputOpIndices, 488 long[] targetOpHandles, 489 boolean wantRunMetadata, 490 long[] outputTensorHandles); 491 } 492