• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import java.util.ArrayList;
19 import java.util.List;
20 
21 /**
22  * Driver for {@link Graph} execution.
23  *
24  * <p>A {@code Session} instance encapsulates the environment in which {@link Operation}s in a
25  * {@link Graph} are executed to compute {@link Tensor Tensors}. For example:
26  *
27  * <pre>{@code
28  * // Let's say graph is an instance of the Graph class
29  * // for the computation y = 3 * x
30  *
31  * try (Session s = new Session(graph)) {
32  *   try (Tensor x = Tensor.create(2.0f);
33  *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
34  *       System.out.println(y.floatValue());  // Will print 6.0f
35  *   }
36  *   try (Tensor x = Tensor.create(1.1f);
37  *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
38  *       System.out.println(y.floatValue());  // Will print 3.3f
39  *   }
40  * }
41  * }</pre>
42  *
43  * <p><b>WARNING:</b>A {@code Session} owns resources that <b>must</b> be explicitly freed by
44  * invoking {@link #close()}.
45  *
46  * <p>Instances of a Session are thread-safe.
47  */
48 public final class Session implements AutoCloseable {
49 
50   /** Construct a new session with the associated {@link Graph}. */
Session(Graph g)51   public Session(Graph g) {
52     this(g, null);
53   }
54 
55   /**
56    * Construct a new session with the associated {@link Graph} and configuration options.
57    *
58    * @param g The {@link Graph} the created Session will operate on.
59    * @param config Configuration parameters for the session specified as a serialized <a
60    *     href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto</a>
61    *     protocol buffer.
62    * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto
63    *     protocol buffer.
64    */
Session(Graph g, byte[] config)65   public Session(Graph g, byte[] config) {
66     graph = g;
67     Graph.Reference r = g.ref();
68     try {
69       nativeHandle =
70           (config == null) ? allocate(r.nativeHandle()) : allocate2(r.nativeHandle(), null, config);
71       graphRef = g.ref();
72     } finally {
73       r.close();
74     }
75   }
76 
77   /** Wrap an existing session with the associated {@link Graph}. */
Session(Graph g, long nativeHandle)78   Session(Graph g, long nativeHandle) {
79     graph = g;
80     this.nativeHandle = nativeHandle;
81     graphRef = g.ref();
82   }
83 
84   /**
85    * Release resources associated with the Session.
86    *
87    * <p>Blocks until there are no active executions ({@link Session.Runner#run()} calls). A Session
88    * is not usable after close returns.
89    */
90   @Override
close()91   public void close() {
92     graphRef.close();
93     synchronized (nativeHandleLock) {
94       if (nativeHandle == 0) {
95         return;
96       }
97       while (numActiveRuns > 0) {
98         try {
99           nativeHandleLock.wait();
100         } catch (InterruptedException e) {
101           Thread.currentThread().interrupt();
102           // Possible leak of the Session and Graph in this case?
103           return;
104         }
105       }
106       delete(nativeHandle);
107       nativeHandle = 0;
108     }
109   }
110 
111   /**
112    * Run {@link Operation}s and evaluate {@link Tensor Tensors}.
113    *
114    * <p>A Runner runs the necessary graph fragments to execute every {@link Operation} required to
115    * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String,int,Tensor)} call allows
116    * callers to override the value of {@link Tensor Tensors} in the graph by substituting the
117    * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link
118    * #feed(String,int,Tensor)}.
119    */
120   public final class Runner {
121     /**
122      * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces.
123      *
124      * @param operation Is either the string name of the operation, in which case this method is a
125      *     shorthand for {@code feed(operation, 0)}, or it is a string of the form
126      *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
127      *     feed(operation_name, output_index)}. These colon-separated names are commonly used in the
128      *     {@code SignatureDef} protocol buffer messages that are included in {@link
129      *     SavedModelBundle#metaGraphDef()}.
130      */
feed(String operation, Tensor<?> t)131     public Runner feed(String operation, Tensor<?> t) {
132       return feed(parseOutput(operation), t);
133     }
134 
135     /**
136      * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t}
137      * for the value it produces.
138      *
139      * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
140      * one {@code t} is being provided for.
141      */
feed(String operation, int index, Tensor<?> t)142     public Runner feed(String operation, int index, Tensor<?> t) {
143       Operation op = operationByName(operation);
144       if (op != null) {
145         inputs.add(op.output(index));
146         inputTensors.add(t);
147       }
148       return this;
149     }
150 
151     /**
152      * Use {@code t} instead of the Tensor referred to by executing the operation referred to by
153      * {@code operand}.
154      */
feed(Operand<?> operand, Tensor<?> t)155     public Runner feed(Operand<?> operand, Tensor<?> t) {
156       inputs.add(operand.asOutput());
157       inputTensors.add(t);
158       return this;
159     }
160 
161     /**
162      * Make {@link #run()} return the output of {@code operation}.
163      *
164      * @param operation Is either the string name of the operation, in which case this method is a
165      *     shorthand for {@code fetch(operation, 0)}, or it is a string of the form
166      *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
167      *     fetch(operation_name, output_index)}. These colon-separated names are commonly used in
168      *     the {@code SignatureDef} protocol buffer messages that are included in {@link
169      *     SavedModelBundle#metaGraphDef()}.
170      */
fetch(String operation)171     public Runner fetch(String operation) {
172       return fetch(parseOutput(operation));
173     }
174 
175     /**
176      * Make {@link #run()} return the {@code index}-th output of {@code operation}.
177      *
178      * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
179      * one to return.
180      */
fetch(String operation, int index)181     public Runner fetch(String operation, int index) {
182       Operation op = operationByName(operation);
183       if (op != null) {
184         outputs.add(op.output(index));
185       }
186       return this;
187     }
188 
189     /**
190      * Makes {@link #run()} return the Tensor referred to by {@code output}.
191      */
fetch(Output<?> output)192     public Runner fetch(Output<?> output) {
193       outputs.add(output);
194       return this;
195     }
196 
197     /**
198      * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
199      */
fetch(Operand<?> operand)200     public Runner fetch(Operand<?> operand) {
201       return fetch(operand.asOutput());
202     }
203 
204     /**
205      * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor
206      * Tensors}.
207      */
addTarget(String operation)208     public Runner addTarget(String operation) {
209       GraphOperation op = operationByName(operation);
210       if (op != null) {
211         targets.add(op);
212       }
213       return this;
214     }
215 
216     /**
217      * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor
218      * Tensors}.
219      *
220      * @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
221      */
addTarget(Operation operation)222     public Runner addTarget(Operation operation) {
223       if (!(operation instanceof GraphOperation)) {
224         throw new IllegalArgumentException(
225             "Operation of type "
226                 + operation.getClass().getName()
227                 + " is not supported in graph sessions");
228       }
229       targets.add((GraphOperation) operation);
230       return this;
231     }
232 
233     /**
234      * Make {@link #run} execute {@code operand}, but not return any evaluated {@link Tensor
235      * Tensors}.
236      */
addTarget(Operand<?> operand)237     public Runner addTarget(Operand<?> operand) {
238       return addTarget(operand.asOutput().op());
239     }
240 
241     /**
242      * (Experimental method): set options (typically for debugging) for this run.
243      *
244      * <p>The options are presented as a serialized <a
245      * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
246      * protocol buffer</a>.
247      *
248      * <p>The org.tensorflow package is free of any protocol buffer dependencies in order to remain
249      * friendly to resource constrained systems (where something like <a
250      * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
251      * be more appropriate). A cost of that is this lack of type-safety in this API function. This
252      * choice is under review and this function may be replaced by more type-safe equivalents at any
253      * time.
254      */
setOptions(byte[] options)255     public Runner setOptions(byte[] options) {
256       this.runOptions = options;
257       return this;
258     }
259 
260     /**
261      * Execute the graph fragments necessary to compute all requested fetches.
262      *
263      * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor Tensors}, i.e.,
264      * the caller must call {@link Tensor#close} on all elements of the returned list to free up
265      * resources.
266      *
267      * <p>TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it
268      * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in
269      * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a
270      * {@code Map<Output, Tensor>}?
271      *
272      * <p>TODO(andrewmyers): It would also be good if whatever is returned here made it easier to
273      * extract output tensors in a type-safe way.
274      */
run()275     public List<Tensor<?>> run() {
276       return runHelper(false).outputs;
277     }
278 
279     /**
280      * Execute graph fragments to compute requested fetches and return metadata about the run.
281      *
282      * <p>This is exactly like {@link #run()}, but in addition to the requested Tensors, also
283      * returns metadata about the graph execution in the form of a serialized <a
284      * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
285      * protocol buffer</a>.
286      */
runAndFetchMetadata()287     public Run runAndFetchMetadata() {
288       return runHelper(true);
289     }
290 
runHelper(boolean wantMetadata)291     private Run runHelper(boolean wantMetadata) {
292       long[] inputTensorHandles = new long[inputTensors.size()];
293       long[] inputOpHandles = new long[inputs.size()];
294       int[] inputOpIndices = new int[inputs.size()];
295       long[] outputOpHandles = new long[outputs.size()];
296       int[] outputOpIndices = new int[outputs.size()];
297       long[] targetOpHandles = new long[targets.size()];
298       long[] outputTensorHandles = new long[outputs.size()];
299 
300       // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
301       // validity of the Graph and graphRef ensures that.
302       int idx = 0;
303       for (Tensor<?> t : inputTensors) {
304         inputTensorHandles[idx++] = t.getNativeHandle();
305       }
306       idx = 0;
307       for (Output<?> o : inputs) {
308         inputOpHandles[idx] = o.getUnsafeNativeHandle();
309         inputOpIndices[idx] = o.index();
310         idx++;
311       }
312       idx = 0;
313       for (Output<?> o : outputs) {
314         outputOpHandles[idx] = o.getUnsafeNativeHandle();
315         outputOpIndices[idx] = o.index();
316         idx++;
317       }
318       idx = 0;
319       for (GraphOperation op : targets) {
320         targetOpHandles[idx++] = op.getUnsafeNativeHandle();
321       }
322       Reference runRef = new Reference();
323       byte[] metadata = null;
324       try {
325         metadata =
326             Session.run(
327                 nativeHandle,
328                 runOptions,
329                 inputTensorHandles,
330                 inputOpHandles,
331                 inputOpIndices,
332                 outputOpHandles,
333                 outputOpIndices,
334                 targetOpHandles,
335                 wantMetadata,
336                 outputTensorHandles);
337       } finally {
338         runRef.close();
339       }
340       List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
341       for (long h : outputTensorHandles) {
342         try {
343           outputs.add(Tensor.fromHandle(h));
344         } catch (Exception e) {
345           for (Tensor<?> t : outputs) {
346             t.close();
347           }
348           outputs.clear();
349           throw e;
350         }
351       }
352       Run ret = new Run();
353       ret.outputs = outputs;
354       ret.metadata = metadata;
355       return ret;
356     }
357 
358     private class Reference implements AutoCloseable {
Reference()359       public Reference() {
360         synchronized (nativeHandleLock) {
361           if (nativeHandle == 0) {
362             throw new IllegalStateException("run() cannot be called on the Session after close()");
363           }
364           ++numActiveRuns;
365         }
366       }
367 
368       @Override
close()369       public void close() {
370         synchronized (nativeHandleLock) {
371           if (nativeHandle == 0) {
372             return;
373           }
374           if (--numActiveRuns == 0) {
375             nativeHandleLock.notifyAll();
376           }
377         }
378       }
379     }
380 
operationByName(String opName)381     private GraphOperation operationByName(String opName) {
382       GraphOperation op = graph.operation(opName);
383       if (op == null) {
384         throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph");
385       }
386       return op;
387     }
388 
389     @SuppressWarnings("rawtypes")
parseOutput(String opName)390     private Output<?> parseOutput(String opName) {
391       int colon = opName.lastIndexOf(':');
392       if (colon == -1 || colon == opName.length() - 1) {
393         return new Output(operationByName(opName), 0);
394       }
395       try {
396         String op = opName.substring(0, colon);
397         int index = Integer.parseInt(opName.substring(colon + 1));
398         return new Output(operationByName(op), index);
399       } catch (NumberFormatException e) {
400         return new Output(operationByName(opName), 0);
401       }
402     }
403 
404     private ArrayList<Output<?>> inputs = new ArrayList<Output<?>>();
405     private ArrayList<Tensor<?>> inputTensors = new ArrayList<Tensor<?>>();
406     private ArrayList<Output<?>> outputs = new ArrayList<Output<?>>();
407     private ArrayList<GraphOperation> targets = new ArrayList<GraphOperation>();
408     private byte[] runOptions = null;
409   }
410 
411   /** Create a Runner to execute graph operations and evaluate Tensors. */
runner()412   public Runner runner() {
413     return new Runner();
414   }
415 
416   /**
417    * Output tensors and metadata obtained when executing a session.
418    *
419    * <p>See {@link Runner#runAndFetchMetadata()}
420    */
421   public static final class Run {
422     /** Tensors from requested fetches. */
423     public List<Tensor<?>> outputs;
424 
425     /**
426      * (Experimental): Metadata about the run.
427      *
428      * <p>A serialized <a
429      * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
430      * protocol buffer</a>. The org.tensorflow package is free of any protocol buffer dependencies
431      * in order to remain friendly to resource constrained systems (where something like <a
432      * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
433      * be more appropriate). A cost of that is this opaque blob. This choice is under review and
434      * this field may be replaced by more type-safe equivalents at any time.
435      */
436     public byte[] metadata;
437   }
438 
439   private final Graph graph;
440   private final Graph.Reference graphRef;
441 
442   private final Object nativeHandleLock = new Object();
443   private long nativeHandle;
444   private int numActiveRuns;
445 
446   // TODO(ashankar): Remove after TensorFlow 1.2 has been released with allocate2().
allocate(long graphHandle)447   private static native long allocate(long graphHandle);
448 
allocate2(long graphHandle, String target, byte[] config)449   private static native long allocate2(long graphHandle, String target, byte[] config);
450 
delete(long handle)451   private static native void delete(long handle);
452 
453   /**
454    * Execute a session.
455    *
456    * <p>The author apologizes for the ugliness of the long argument list of this method. However,
457    * take solace in the fact that this is a private method meant to cross the JNI boundary.
458    *
459    * @param handle to the C API TF_Session object (Session.nativeHandle)
460    * @param runOptions serialized representation of a RunOptions protocol buffer, or null
461    * @param inputOpHandles (see inputOpIndices)
462    * @param inputOpIndices (see inputTensorHandles)
463    * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values
464    *     that are being "fed" (do not need to be computed) during graph execution.
465    *     inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the
466    *     inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that
467    *     inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
468    * @param outputOpHandles (see outputOpIndices)
469    * @param outputOpIndices together with outputOpHandles identifies the set of values that should
470    *     be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is
471    *     required that outputOpHandles.length == outputOpIndices.length.
472    * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose
473    *     output will not be returned
474    * @param wantRunMetadata indicates whether metadata about this execution should be returned.
475    * @param outputTensorHandles will be filled in with handles to the outputs requested. It is
476    *     required that outputTensorHandles.length == outputOpHandles.length.
477    * @return if wantRunMetadata is true, serialized representation of the RunMetadata protocol
478    *     buffer, false otherwise.
479    */
run( long handle, byte[] runOptions, long[] inputTensorHandles, long[] inputOpHandles, int[] inputOpIndices, long[] outputOpHandles, int[] outputOpIndices, long[] targetOpHandles, boolean wantRunMetadata, long[] outputTensorHandles)480   private static native byte[] run(
481       long handle,
482       byte[] runOptions,
483       long[] inputTensorHandles,
484       long[] inputOpHandles,
485       int[] inputOpIndices,
486       long[] outputOpHandles,
487       int[] outputOpIndices,
488       long[] targetOpHandles,
489       boolean wantRunMetadata,
490       long[] outputTensorHandles);
491 }
492