1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.contrib.android; 17 18 import android.content.res.AssetManager; 19 import android.os.Build.VERSION; 20 import android.os.Trace; 21 import android.text.TextUtils; 22 import android.util.Log; 23 import java.io.ByteArrayOutputStream; 24 import java.io.FileInputStream; 25 import java.io.IOException; 26 import java.io.InputStream; 27 import java.nio.ByteBuffer; 28 import java.nio.DoubleBuffer; 29 import java.nio.FloatBuffer; 30 import java.nio.IntBuffer; 31 import java.nio.LongBuffer; 32 import java.util.ArrayList; 33 import java.util.List; 34 import org.tensorflow.Graph; 35 import org.tensorflow.Operation; 36 import org.tensorflow.Session; 37 import org.tensorflow.Tensor; 38 import org.tensorflow.TensorFlow; 39 import org.tensorflow.Tensors; 40 import org.tensorflow.types.UInt8; 41 42 /** 43 * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface 44 * for inference. 45 * 46 * <p>See tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowImageClassifier.java for 47 * an example usage. 48 */ 49 public class TensorFlowInferenceInterface { 50 private static final String TAG = "TensorFlowInferenceInterface"; 51 private static final String ASSET_FILE_PREFIX = "file:///android_asset/"; 52 53 /* 54 * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file. 55 * 56 * @param assetManager The AssetManager to use to load the model file. 57 * @param model The filepath to the GraphDef proto representing the model. 58 */ TensorFlowInferenceInterface(AssetManager assetManager, String model)59 public TensorFlowInferenceInterface(AssetManager assetManager, String model) { 60 prepareNativeRuntime(); 61 62 this.modelName = model; 63 this.g = new Graph(); 64 this.sess = new Session(g); 65 this.runner = sess.runner(); 66 67 final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX); 68 InputStream is = null; 69 try { 70 String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model; 71 is = assetManager.open(aname); 72 } catch (IOException e) { 73 if (hasAssetPrefix) { 74 throw new RuntimeException("Failed to load model from '" + model + "'", e); 75 } 76 // Perhaps the model file is not an asset but is on disk. 77 try { 78 is = new FileInputStream(model); 79 } catch (IOException e2) { 80 throw new RuntimeException("Failed to load model from '" + model + "'", e); 81 } 82 } 83 84 try { 85 if (VERSION.SDK_INT >= 18) { 86 Trace.beginSection("initializeTensorFlow"); 87 Trace.beginSection("readGraphDef"); 88 } 89 90 // TODO(ashankar): Can we somehow mmap the contents instead of copying them? 91 byte[] graphDef = new byte[is.available()]; 92 final int numBytesRead = is.read(graphDef); 93 if (numBytesRead != graphDef.length) { 94 throw new IOException( 95 "read error: read only " 96 + numBytesRead 97 + " of the graph, expected to read " 98 + graphDef.length); 99 } 100 101 if (VERSION.SDK_INT >= 18) { 102 Trace.endSection(); // readGraphDef. 103 } 104 105 loadGraph(graphDef, g); 106 is.close(); 107 Log.i(TAG, "Successfully loaded model from '" + model + "'"); 108 109 if (VERSION.SDK_INT >= 18) { 110 Trace.endSection(); // initializeTensorFlow. 111 } 112 } catch (IOException e) { 113 throw new RuntimeException("Failed to load model from '" + model + "'", e); 114 } 115 } 116 117 /* 118 * Load a TensorFlow model from provided InputStream. 119 * Note: The InputStream will not be closed after loading model, users need to 120 * close it themselves. 121 * 122 * @param is The InputStream to use to load the model. 123 */ TensorFlowInferenceInterface(InputStream is)124 public TensorFlowInferenceInterface(InputStream is) { 125 prepareNativeRuntime(); 126 127 // modelName is redundant for model loading from input stream, here is for 128 // avoiding error in initialization as modelName is marked final. 129 this.modelName = ""; 130 this.g = new Graph(); 131 this.sess = new Session(g); 132 this.runner = sess.runner(); 133 134 try { 135 if (VERSION.SDK_INT >= 18) { 136 Trace.beginSection("initializeTensorFlow"); 137 Trace.beginSection("readGraphDef"); 138 } 139 140 int baosInitSize = is.available() > 16384 ? is.available() : 16384; 141 ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize); 142 int numBytesRead; 143 byte[] buf = new byte[16384]; 144 while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) { 145 baos.write(buf, 0, numBytesRead); 146 } 147 byte[] graphDef = baos.toByteArray(); 148 149 if (VERSION.SDK_INT >= 18) { 150 Trace.endSection(); // readGraphDef. 151 } 152 153 loadGraph(graphDef, g); 154 Log.i(TAG, "Successfully loaded model from the input stream"); 155 156 if (VERSION.SDK_INT >= 18) { 157 Trace.endSection(); // initializeTensorFlow. 158 } 159 } catch (IOException e) { 160 throw new RuntimeException("Failed to load model from the input stream", e); 161 } 162 } 163 164 /* 165 * Construct a TensorFlowInferenceInterface with provided Graph 166 * 167 * @param g The Graph to use to construct this interface. 168 */ TensorFlowInferenceInterface(Graph g)169 public TensorFlowInferenceInterface(Graph g) { 170 prepareNativeRuntime(); 171 172 // modelName is redundant here, here is for 173 // avoiding error in initialization as modelName is marked final. 174 this.modelName = ""; 175 this.g = g; 176 this.sess = new Session(g); 177 this.runner = sess.runner(); 178 } 179 180 /** 181 * Runs inference between the previously registered input nodes (via feed*) and the requested 182 * output nodes. Output nodes can then be queried with the fetch* methods. 183 * 184 * @param outputNames A list of output nodes which should be filled by the inference pass. 185 */ run(String[] outputNames)186 public void run(String[] outputNames) { 187 run(outputNames, false); 188 } 189 190 /** 191 * Runs inference between the previously registered input nodes (via feed*) and the requested 192 * output nodes. Output nodes can then be queried with the fetch* methods. 193 * 194 * @param outputNames A list of output nodes which should be filled by the inference pass. 195 */ run(String[] outputNames, boolean enableStats)196 public void run(String[] outputNames, boolean enableStats) { 197 run(outputNames, enableStats, new String[] {}); 198 } 199 200 /** An overloaded version of runInference that allows supplying targetNodeNames as well */ run(String[] outputNames, boolean enableStats, String[] targetNodeNames)201 public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) { 202 // Release any Tensors from the previous run calls. 203 closeFetches(); 204 205 // Add fetches. 206 for (String o : outputNames) { 207 fetchNames.add(o); 208 TensorId tid = TensorId.parse(o); 209 runner.fetch(tid.name, tid.outputIndex); 210 } 211 212 // Add targets. 213 for (String t : targetNodeNames) { 214 runner.addTarget(t); 215 } 216 217 // Run the session. 218 try { 219 if (enableStats) { 220 Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata(); 221 fetchTensors = r.outputs; 222 223 if (runStats == null) { 224 runStats = new RunStats(); 225 } 226 runStats.add(r.metadata); 227 } else { 228 fetchTensors = runner.run(); 229 } 230 } catch (RuntimeException e) { 231 // Ideally the exception would have been let through, but since this interface predates the 232 // TensorFlow Java API, must return -1. 233 Log.e( 234 TAG, 235 "Failed to run TensorFlow inference with inputs:[" 236 + TextUtils.join(", ", feedNames) 237 + "], outputs:[" 238 + TextUtils.join(", ", fetchNames) 239 + "]"); 240 throw e; 241 } finally { 242 // Always release the feeds (to save resources) and reset the runner, this run is 243 // over. 244 closeFeeds(); 245 runner = sess.runner(); 246 } 247 } 248 249 /** Returns a reference to the Graph describing the computation run during inference. */ graph()250 public Graph graph() { 251 return g; 252 } 253 graphOperation(String operationName)254 public Operation graphOperation(String operationName) { 255 final Operation operation = g.operation(operationName); 256 if (operation == null) { 257 throw new RuntimeException( 258 "Node '" + operationName + "' does not exist in model '" + modelName + "'"); 259 } 260 return operation; 261 } 262 263 /** Returns the last stat summary string if logging is enabled. */ getStatString()264 public String getStatString() { 265 return (runStats == null) ? "" : runStats.summary(); 266 } 267 268 /** 269 * Cleans up the state associated with this Object. 270 * 271 * <p>The TenosrFlowInferenceInterface object is no longer usable after this method returns. 272 */ close()273 public void close() { 274 closeFeeds(); 275 closeFetches(); 276 sess.close(); 277 g.close(); 278 if (runStats != null) { 279 runStats.close(); 280 } 281 runStats = null; 282 } 283 284 @Override finalize()285 protected void finalize() throws Throwable { 286 try { 287 close(); 288 } finally { 289 super.finalize(); 290 } 291 } 292 293 // Methods for taking a native Tensor and filling it with values from Java arrays. 294 295 /** 296 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 297 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 298 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 299 * destination has capacity, the copy is truncated. 300 */ feed(String inputName, boolean[] src, long... dims)301 public void feed(String inputName, boolean[] src, long... dims) { 302 byte[] b = new byte[src.length]; 303 304 for (int i = 0; i < src.length; i++) { 305 b[i] = src[i] ? (byte) 1 : (byte) 0; 306 } 307 308 addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b))); 309 } 310 311 /** 312 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 313 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 314 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 315 * destination has capacity, the copy is truncated. 316 */ feed(String inputName, float[] src, long... dims)317 public void feed(String inputName, float[] src, long... dims) { 318 addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src))); 319 } 320 321 /** 322 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 323 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 324 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 325 * destination has capacity, the copy is truncated. 326 */ feed(String inputName, int[] src, long... dims)327 public void feed(String inputName, int[] src, long... dims) { 328 addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src))); 329 } 330 331 /** 332 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 333 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 334 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 335 * destination has capacity, the copy is truncated. 336 */ feed(String inputName, long[] src, long... dims)337 public void feed(String inputName, long[] src, long... dims) { 338 addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src))); 339 } 340 341 /** 342 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 343 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 344 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 345 * destination has capacity, the copy is truncated. 346 */ feed(String inputName, double[] src, long... dims)347 public void feed(String inputName, double[] src, long... dims) { 348 addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src))); 349 } 350 351 /** 352 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 353 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 354 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 355 * destination has capacity, the copy is truncated. 356 */ feed(String inputName, byte[] src, long... dims)357 public void feed(String inputName, byte[] src, long... dims) { 358 addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src))); 359 } 360 361 /** 362 * Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued 363 * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not 364 * a Java {@code String} (which is a sequence of characters). 365 */ feedString(String inputName, byte[] src)366 public void feedString(String inputName, byte[] src) { 367 addFeed(inputName, Tensors.create(src)); 368 } 369 370 /** 371 * Copy an array of byte sequences into the input Tensor with name {@link inputName} as a 372 * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an 373 * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters). 374 */ feedString(String inputName, byte[][] src)375 public void feedString(String inputName, byte[][] src) { 376 addFeed(inputName, Tensors.create(src)); 377 } 378 379 // Methods for taking a native Tensor and filling it with src from Java native IO buffers. 380 381 /** 382 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 383 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 384 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 385 * elements as that of the destination Tensor. If {@link src} has more elements than the 386 * destination has capacity, the copy is truncated. 387 */ feed(String inputName, FloatBuffer src, long... dims)388 public void feed(String inputName, FloatBuffer src, long... dims) { 389 addFeed(inputName, Tensor.create(dims, src)); 390 } 391 392 /** 393 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 394 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 395 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 396 * elements as that of the destination Tensor. If {@link src} has more elements than the 397 * destination has capacity, the copy is truncated. 398 */ feed(String inputName, IntBuffer src, long... dims)399 public void feed(String inputName, IntBuffer src, long... dims) { 400 addFeed(inputName, Tensor.create(dims, src)); 401 } 402 403 /** 404 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 405 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 406 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 407 * elements as that of the destination Tensor. If {@link src} has more elements than the 408 * destination has capacity, the copy is truncated. 409 */ feed(String inputName, LongBuffer src, long... dims)410 public void feed(String inputName, LongBuffer src, long... dims) { 411 addFeed(inputName, Tensor.create(dims, src)); 412 } 413 414 /** 415 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 416 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 417 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 418 * elements as that of the destination Tensor. If {@link src} has more elements than the 419 * destination has capacity, the copy is truncated. 420 */ feed(String inputName, DoubleBuffer src, long... dims)421 public void feed(String inputName, DoubleBuffer src, long... dims) { 422 addFeed(inputName, Tensor.create(dims, src)); 423 } 424 425 /** 426 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 427 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 428 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 429 * elements as that of the destination Tensor. If {@link src} has more elements than the 430 * destination has capacity, the copy is truncated. 431 */ feed(String inputName, ByteBuffer src, long... dims)432 public void feed(String inputName, ByteBuffer src, long... dims) { 433 addFeed(inputName, Tensor.create(UInt8.class, dims, src)); 434 } 435 436 /** 437 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 438 * dst} must have length greater than or equal to that of the source Tensor. This operation will 439 * not affect dst's content past the source Tensor's size. 440 */ fetch(String outputName, float[] dst)441 public void fetch(String outputName, float[] dst) { 442 fetch(outputName, FloatBuffer.wrap(dst)); 443 } 444 445 /** 446 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 447 * dst} must have length greater than or equal to that of the source Tensor. This operation will 448 * not affect dst's content past the source Tensor's size. 449 */ fetch(String outputName, int[] dst)450 public void fetch(String outputName, int[] dst) { 451 fetch(outputName, IntBuffer.wrap(dst)); 452 } 453 454 /** 455 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 456 * dst} must have length greater than or equal to that of the source Tensor. This operation will 457 * not affect dst's content past the source Tensor's size. 458 */ fetch(String outputName, long[] dst)459 public void fetch(String outputName, long[] dst) { 460 fetch(outputName, LongBuffer.wrap(dst)); 461 } 462 463 /** 464 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 465 * dst} must have length greater than or equal to that of the source Tensor. This operation will 466 * not affect dst's content past the source Tensor's size. 467 */ fetch(String outputName, double[] dst)468 public void fetch(String outputName, double[] dst) { 469 fetch(outputName, DoubleBuffer.wrap(dst)); 470 } 471 472 /** 473 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 474 * dst} must have length greater than or equal to that of the source Tensor. This operation will 475 * not affect dst's content past the source Tensor's size. 476 */ fetch(String outputName, byte[] dst)477 public void fetch(String outputName, byte[] dst) { 478 fetch(outputName, ByteBuffer.wrap(dst)); 479 } 480 481 /** 482 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 483 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 484 * or equal to that of the source Tensor. This operation will not affect dst's content past the 485 * source Tensor's size. 486 */ fetch(String outputName, FloatBuffer dst)487 public void fetch(String outputName, FloatBuffer dst) { 488 getTensor(outputName).writeTo(dst); 489 } 490 491 /** 492 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 493 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 494 * or equal to that of the source Tensor. This operation will not affect dst's content past the 495 * source Tensor's size. 496 */ fetch(String outputName, IntBuffer dst)497 public void fetch(String outputName, IntBuffer dst) { 498 getTensor(outputName).writeTo(dst); 499 } 500 501 /** 502 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 503 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 504 * or equal to that of the source Tensor. This operation will not affect dst's content past the 505 * source Tensor's size. 506 */ fetch(String outputName, LongBuffer dst)507 public void fetch(String outputName, LongBuffer dst) { 508 getTensor(outputName).writeTo(dst); 509 } 510 511 /** 512 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 513 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 514 * or equal to that of the source Tensor. This operation will not affect dst's content past the 515 * source Tensor's size. 516 */ fetch(String outputName, DoubleBuffer dst)517 public void fetch(String outputName, DoubleBuffer dst) { 518 getTensor(outputName).writeTo(dst); 519 } 520 521 /** 522 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 523 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 524 * or equal to that of the source Tensor. This operation will not affect dst's content past the 525 * source Tensor's size. 526 */ fetch(String outputName, ByteBuffer dst)527 public void fetch(String outputName, ByteBuffer dst) { 528 getTensor(outputName).writeTo(dst); 529 } 530 prepareNativeRuntime()531 private void prepareNativeRuntime() { 532 Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); 533 try { 534 // Hack to see if the native libraries have been loaded. 535 new RunStats(); 536 Log.i(TAG, "TensorFlow native methods already loaded"); 537 } catch (UnsatisfiedLinkError e1) { 538 Log.i( 539 TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference"); 540 try { 541 System.loadLibrary("tensorflow_inference"); 542 Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)"); 543 } catch (UnsatisfiedLinkError e2) { 544 throw new RuntimeException( 545 "Native TF methods not found; check that the correct native" 546 + " libraries are present in the APK: " 547 + e2); 548 } 549 } 550 } 551 loadGraph(byte[] graphDef, Graph g)552 private void loadGraph(byte[] graphDef, Graph g) throws IOException { 553 final long startMs = System.currentTimeMillis(); 554 555 if (VERSION.SDK_INT >= 18) { 556 Trace.beginSection("importGraphDef"); 557 } 558 559 try { 560 g.importGraphDef(graphDef); 561 } catch (IllegalArgumentException e) { 562 throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage()); 563 } 564 565 if (VERSION.SDK_INT >= 18) { 566 Trace.endSection(); // importGraphDef. 567 } 568 569 final long endMs = System.currentTimeMillis(); 570 Log.i( 571 TAG, 572 "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version()); 573 } 574 addFeed(String inputName, Tensor<?> t)575 private void addFeed(String inputName, Tensor<?> t) { 576 // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index]. 577 TensorId tid = TensorId.parse(inputName); 578 runner.feed(tid.name, tid.outputIndex, t); 579 feedNames.add(inputName); 580 feedTensors.add(t); 581 } 582 583 private static class TensorId { 584 String name; 585 int outputIndex; 586 587 // Parse output names into a TensorId. 588 // 589 // E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1) parse(String name)590 public static TensorId parse(String name) { 591 TensorId tid = new TensorId(); 592 int colonIndex = name.lastIndexOf(':'); 593 if (colonIndex < 0) { 594 tid.outputIndex = 0; 595 tid.name = name; 596 return tid; 597 } 598 try { 599 tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1)); 600 tid.name = name.substring(0, colonIndex); 601 } catch (NumberFormatException e) { 602 tid.outputIndex = 0; 603 tid.name = name; 604 } 605 return tid; 606 } 607 } 608 getTensor(String outputName)609 private Tensor<?> getTensor(String outputName) { 610 int i = 0; 611 for (String n : fetchNames) { 612 if (n.equals(outputName)) { 613 return fetchTensors.get(i); 614 } 615 ++i; 616 } 617 throw new RuntimeException( 618 "Node '" + outputName + "' was not provided to run(), so it cannot be read"); 619 } 620 closeFeeds()621 private void closeFeeds() { 622 for (Tensor<?> t : feedTensors) { 623 t.close(); 624 } 625 feedTensors.clear(); 626 feedNames.clear(); 627 } 628 closeFetches()629 private void closeFetches() { 630 for (Tensor<?> t : fetchTensors) { 631 t.close(); 632 } 633 fetchTensors.clear(); 634 fetchNames.clear(); 635 } 636 637 // Immutable state. 638 private final String modelName; 639 private final Graph g; 640 private final Session sess; 641 642 // State reset on every call to run. 643 private Session.Runner runner; 644 private List<String> feedNames = new ArrayList<String>(); 645 private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>(); 646 private List<String> fetchNames = new ArrayList<String>(); 647 private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>(); 648 649 // Mutable state. 650 private RunStats runStats; 651 } 652