1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 import static com.google.common.truth.Truth.assertThat; 19 import static org.junit.Assert.fail; 20 21 import java.io.File; 22 import java.nio.ByteBuffer; 23 import java.nio.ByteOrder; 24 import java.nio.FloatBuffer; 25 import java.nio.MappedByteBuffer; 26 import java.nio.channels.FileChannel; 27 import java.nio.file.Files; 28 import java.nio.file.Path; 29 import java.nio.file.StandardOpenOption; 30 import java.util.EnumSet; 31 import java.util.HashMap; 32 import java.util.Map; 33 import org.junit.Test; 34 import org.junit.runner.RunWith; 35 import org.junit.runners.JUnit4; 36 37 /** Unit tests for {@link org.tensorflow.lite.Interpreter}. */ 38 @RunWith(JUnit4.class) 39 public final class InterpreterTest { 40 41 private static final File MODEL_FILE = 42 new File("tensorflow/lite/java/src/testdata/add.bin"); 43 44 private static final File MULTIPLE_INPUTS_MODEL_FILE = 45 new File("tensorflow/lite/testdata/multi_add.bin"); 46 47 private static final File FLEX_MODEL_FILE = 48 new File("tensorflow/lite/testdata/multi_add_flex.bin"); 49 50 @Test testInterpreter()51 public void testInterpreter() throws Exception { 52 Interpreter interpreter = new Interpreter(MODEL_FILE); 53 assertThat(interpreter).isNotNull(); 54 assertThat(interpreter.getInputTensorCount()).isEqualTo(1); 55 assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); 56 assertThat(interpreter.getOutputTensorCount()).isEqualTo(1); 57 assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); 58 interpreter.close(); 59 } 60 61 @Test testInterpreterWithOptions()62 public void testInterpreterWithOptions() throws Exception { 63 Interpreter interpreter = 64 new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true)); 65 assertThat(interpreter).isNotNull(); 66 assertThat(interpreter.getInputTensorCount()).isEqualTo(1); 67 assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); 68 assertThat(interpreter.getOutputTensorCount()).isEqualTo(1); 69 assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); 70 interpreter.close(); 71 } 72 73 @Test testRunWithMappedByteBufferModel()74 public void testRunWithMappedByteBufferModel() throws Exception { 75 Path path = MODEL_FILE.toPath(); 76 FileChannel fileChannel = 77 (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); 78 ByteBuffer mappedByteBuffer = 79 fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); 80 Interpreter interpreter = new Interpreter(mappedByteBuffer); 81 float[] oneD = {1.23f, 6.54f, 7.81f}; 82 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 83 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 84 float[][][][] fourD = {threeD, threeD}; 85 float[][][][] parsedOutputs = new float[2][8][8][3]; 86 interpreter.run(fourD, parsedOutputs); 87 float[] outputOneD = parsedOutputs[0][0][0]; 88 float[] expected = {3.69f, 19.62f, 23.43f}; 89 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 90 interpreter.close(); 91 fileChannel.close(); 92 } 93 94 @Test testRunWithDirectByteBufferModel()95 public void testRunWithDirectByteBufferModel() throws Exception { 96 Path path = MODEL_FILE.toPath(); 97 FileChannel fileChannel = 98 (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); 99 ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) fileChannel.size()); 100 byteBuffer.order(ByteOrder.nativeOrder()); 101 fileChannel.read(byteBuffer); 102 Interpreter interpreter = new Interpreter(byteBuffer); 103 float[] oneD = {1.23f, 6.54f, 7.81f}; 104 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 105 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 106 float[][][][] fourD = {threeD, threeD}; 107 float[][][][] parsedOutputs = new float[2][8][8][3]; 108 interpreter.run(fourD, parsedOutputs); 109 float[] outputOneD = parsedOutputs[0][0][0]; 110 float[] expected = {3.69f, 19.62f, 23.43f}; 111 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 112 interpreter.close(); 113 fileChannel.close(); 114 } 115 116 @Test testRunWithInvalidByteBufferModel()117 public void testRunWithInvalidByteBufferModel() throws Exception { 118 Path path = MODEL_FILE.toPath(); 119 FileChannel fileChannel = 120 (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); 121 ByteBuffer byteBuffer = ByteBuffer.allocate((int) fileChannel.size()); 122 byteBuffer.order(ByteOrder.nativeOrder()); 123 fileChannel.read(byteBuffer); 124 try { 125 new Interpreter(byteBuffer); 126 fail(); 127 } catch (IllegalArgumentException e) { 128 assertThat(e) 129 .hasMessageThat() 130 .contains( 131 "Model ByteBuffer should be either a MappedByteBuffer" 132 + " of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder()"); 133 } 134 fileChannel.close(); 135 } 136 137 @Test testRun()138 public void testRun() { 139 Interpreter interpreter = new Interpreter(MODEL_FILE); 140 Float[] oneD = {1.23f, 6.54f, 7.81f}; 141 Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 142 Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 143 Float[][][][] fourD = {threeD, threeD}; 144 Float[][][][] parsedOutputs = new Float[2][8][8][3]; 145 try { 146 interpreter.run(fourD, parsedOutputs); 147 fail(); 148 } catch (IllegalArgumentException e) { 149 assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;"); 150 } 151 interpreter.close(); 152 } 153 154 @Test testRunWithBoxedInputs()155 public void testRunWithBoxedInputs() { 156 Interpreter interpreter = new Interpreter(MODEL_FILE); 157 float[] oneD = {1.23f, 6.54f, 7.81f}; 158 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 159 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 160 float[][][][] fourD = {threeD, threeD}; 161 float[][][][] parsedOutputs = new float[2][8][8][3]; 162 interpreter.run(fourD, parsedOutputs); 163 float[] outputOneD = parsedOutputs[0][0][0]; 164 float[] expected = {3.69f, 19.62f, 23.43f}; 165 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 166 interpreter.close(); 167 } 168 169 @Test testRunForMultipleInputsOutputs()170 public void testRunForMultipleInputsOutputs() { 171 Interpreter interpreter = new Interpreter(MULTIPLE_INPUTS_MODEL_FILE); 172 assertThat(interpreter.getInputTensorCount()).isEqualTo(4); 173 assertThat(interpreter.getInputTensor(0).index()).isGreaterThan(-1); 174 assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); 175 assertThat(interpreter.getInputTensor(1).dataType()).isEqualTo(DataType.FLOAT32); 176 assertThat(interpreter.getInputTensor(2).dataType()).isEqualTo(DataType.FLOAT32); 177 assertThat(interpreter.getInputTensor(3).dataType()).isEqualTo(DataType.FLOAT32); 178 assertThat(interpreter.getOutputTensorCount()).isEqualTo(2); 179 assertThat(interpreter.getOutputTensor(0).index()).isGreaterThan(-1); 180 assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); 181 assertThat(interpreter.getOutputTensor(1).dataType()).isEqualTo(DataType.FLOAT32); 182 183 float[] input0 = {1.23f}; 184 float[] input1 = {2.43f}; 185 Object[] inputs = {input0, input1, input0, input1}; 186 float[] parsedOutput0 = new float[1]; 187 float[] parsedOutput1 = new float[1]; 188 Map<Integer, Object> outputs = new HashMap<>(); 189 outputs.put(0, parsedOutput0); 190 outputs.put(1, parsedOutput1); 191 interpreter.runForMultipleInputsOutputs(inputs, outputs); 192 float[] expected0 = {4.89f}; 193 float[] expected1 = {6.09f}; 194 assertThat(parsedOutput0).usingTolerance(0.1f).containsExactly(expected0).inOrder(); 195 assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder(); 196 } 197 198 @Test testRunWithByteBufferOutput()199 public void testRunWithByteBufferOutput() { 200 float[] oneD = {1.23f, 6.54f, 7.81f}; 201 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 202 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 203 float[][][][] fourD = {threeD, threeD}; 204 ByteBuffer parsedOutput = 205 ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); 206 try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { 207 interpreter.run(fourD, parsedOutput); 208 } 209 float[] outputOneD = { 210 parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) 211 }; 212 float[] expected = {3.69f, 19.62f, 23.43f}; 213 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 214 } 215 216 @Test testResizeInput()217 public void testResizeInput() { 218 try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { 219 int[] inputDims = {1}; 220 interpreter.resizeInput(0, inputDims); 221 assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims); 222 ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); 223 ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); 224 interpreter.run(input, output); 225 assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); 226 } 227 } 228 229 @Test testRunWithWrongInputType()230 public void testRunWithWrongInputType() { 231 Interpreter interpreter = new Interpreter(MODEL_FILE); 232 int[] oneD = {4, 3, 9}; 233 int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 234 int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 235 int[][][][] fourD = {threeD, threeD}; 236 float[][][][] parsedOutputs = new float[2][8][8][3]; 237 try { 238 interpreter.run(fourD, parsedOutputs); 239 fail(); 240 } catch (IllegalArgumentException e) { 241 assertThat(e) 242 .hasMessageThat() 243 .contains( 244 "Cannot convert between a TensorFlowLite tensor with type " 245 + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" 246 + " TensorFlowLite type INT32)"); 247 } 248 interpreter.close(); 249 } 250 251 @Test testRunWithUnsupportedInputType()252 public void testRunWithUnsupportedInputType() { 253 FloatBuffer floatBuffer = FloatBuffer.allocate(10); 254 float[][][][] parsedOutputs = new float[2][8][8][3]; 255 try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { 256 interpreter.run(floatBuffer, parsedOutputs); 257 fail(); 258 } catch (IllegalArgumentException e) { 259 assertThat(e).hasMessageThat().contains("DataType error: cannot resolve DataType of"); 260 } 261 } 262 263 @Test testRunWithWrongOutputType()264 public void testRunWithWrongOutputType() { 265 Interpreter interpreter = new Interpreter(MODEL_FILE); 266 float[] oneD = {1.23f, 6.54f, 7.81f}; 267 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 268 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 269 float[][][][] fourD = {threeD, threeD}; 270 int[][][][] parsedOutputs = new int[2][8][8][3]; 271 try { 272 interpreter.run(fourD, parsedOutputs); 273 fail(); 274 } catch (IllegalArgumentException e) { 275 assertThat(e) 276 .hasMessageThat() 277 .contains( 278 "Cannot convert between a TensorFlowLite tensor with type " 279 + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" 280 + " TensorFlowLite type INT32)"); 281 } 282 interpreter.close(); 283 } 284 285 @Test testGetInputIndex()286 public void testGetInputIndex() { 287 Interpreter interpreter = new Interpreter(MODEL_FILE); 288 try { 289 interpreter.getInputIndex("WrongInputName"); 290 fail(); 291 } catch (IllegalArgumentException e) { 292 assertThat(e) 293 .hasMessageThat() 294 .contains( 295 "'WrongInputName' is not a valid name for any input. Names of inputs and their " 296 + "indexes are {input=0}"); 297 } 298 int index = interpreter.getInputIndex("input"); 299 assertThat(index).isEqualTo(0); 300 } 301 302 @Test testGetOutputIndex()303 public void testGetOutputIndex() { 304 Interpreter interpreter = new Interpreter(MODEL_FILE); 305 try { 306 interpreter.getOutputIndex("WrongOutputName"); 307 fail(); 308 } catch (IllegalArgumentException e) { 309 assertThat(e) 310 .hasMessageThat() 311 .contains( 312 "'WrongOutputName' is not a valid name for any output. Names of outputs and their" 313 + " indexes are {output=0}"); 314 } 315 int index = interpreter.getOutputIndex("output"); 316 assertThat(index).isEqualTo(0); 317 } 318 319 @Test testTurnOnNNAPI()320 public void testTurnOnNNAPI() throws Exception { 321 Path path = MODEL_FILE.toPath(); 322 FileChannel fileChannel = 323 (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); 324 MappedByteBuffer mappedByteBuffer = 325 fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); 326 Interpreter interpreter = 327 new Interpreter( 328 mappedByteBuffer, 329 new Interpreter.Options().setUseNNAPI(true).setAllowFp16PrecisionForFp32(true)); 330 float[] oneD = {1.23f, 6.54f, 7.81f}; 331 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 332 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 333 float[][][][] fourD = {threeD, threeD}; 334 float[][][][] parsedOutputs = new float[2][8][8][3]; 335 interpreter.run(fourD, parsedOutputs); 336 float[] outputOneD = parsedOutputs[0][0][0]; 337 float[] expected = {3.69f, 19.62f, 23.43f}; 338 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 339 interpreter.close(); 340 fileChannel.close(); 341 } 342 343 @Test testRedundantClose()344 public void testRedundantClose() throws Exception { 345 Interpreter interpreter = new Interpreter(MODEL_FILE); 346 interpreter.close(); 347 interpreter.close(); 348 } 349 350 @Test testNullInputs()351 public void testNullInputs() throws Exception { 352 Interpreter interpreter = new Interpreter(MODEL_FILE); 353 try { 354 interpreter.run(null, new float[2][8][8][3]); 355 fail(); 356 } catch (IllegalArgumentException e) { 357 // Expected failure. 358 } 359 interpreter.close(); 360 } 361 362 @Test testNullOutputs()363 public void testNullOutputs() throws Exception { 364 Interpreter interpreter = new Interpreter(MODEL_FILE); 365 try { 366 interpreter.run(new float[2][8][8][3], null); 367 fail(); 368 } catch (IllegalArgumentException e) { 369 // Expected failure. 370 } 371 interpreter.close(); 372 } 373 374 /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */ 375 @Test testFlexModel()376 public void testFlexModel() throws Exception { 377 try { 378 new Interpreter(FLEX_MODEL_FILE); 379 fail(); 380 } catch (IllegalStateException e) { 381 // Expected failure. 382 } 383 } 384 385 @Test testDelegate()386 public void testDelegate() throws Exception { 387 System.loadLibrary("tensorflowlite_test_jni"); 388 Delegate delegate = 389 new Delegate() { 390 @Override 391 public long getNativeHandle() { 392 return getNativeHandleForDelegate(); 393 } 394 }; 395 Interpreter interpreter = 396 new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate)); 397 398 // The native delegate stubs out the graph with a single op that produces the scalar value 7. 399 float[] oneD = {1.23f, 6.54f, 7.81f}; 400 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 401 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 402 float[][][][] fourD = {threeD, threeD}; 403 float[][][][] parsedOutputs = new float[2][8][8][3]; 404 interpreter.run(fourD, parsedOutputs); 405 float[] outputOneD = parsedOutputs[0][0][0]; 406 float[] expected = {7.0f, 7.0f, 7.0f}; 407 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 408 409 interpreter.close(); 410 } 411 412 @Test testNullInputsAndOutputsWithDelegate()413 public void testNullInputsAndOutputsWithDelegate() throws Exception { 414 System.loadLibrary("tensorflowlite_test_jni"); 415 Delegate delegate = 416 new Delegate() { 417 @Override 418 public long getNativeHandle() { 419 return getNativeHandleForDelegate(); 420 } 421 }; 422 Interpreter interpreter = 423 new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate)); 424 // The delegate installs a custom buffer handle for all tensors, in turn allowing null to be 425 // provided for the inputs/outputs (as the client can reference the buffer directly). 426 interpreter.run(new float[2][8][8][3], null); 427 interpreter.run(null, new float[2][8][8][3]); 428 interpreter.close(); 429 } 430 431 @Test testModifyGraphWithDelegate()432 public void testModifyGraphWithDelegate() throws Exception { 433 System.loadLibrary("tensorflowlite_test_jni"); 434 Delegate delegate = 435 new Delegate() { 436 @Override 437 public long getNativeHandle() { 438 return getNativeHandleForDelegate(); 439 } 440 }; 441 Interpreter interpreter = new Interpreter(MODEL_FILE); 442 interpreter.modifyGraphWithDelegate(delegate); 443 444 // The native delegate stubs out the graph with a single op that produces the scalar value 7. 445 float[] oneD = {1.23f, 6.54f, 7.81f}; 446 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 447 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 448 float[][][][] fourD = {threeD, threeD}; 449 float[][][][] parsedOutputs = new float[2][8][8][3]; 450 interpreter.run(fourD, parsedOutputs); 451 float[] outputOneD = parsedOutputs[0][0][0]; 452 float[] expected = {7.0f, 7.0f, 7.0f}; 453 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 454 455 interpreter.close(); 456 } 457 458 @Test testInvalidDelegate()459 public void testInvalidDelegate() throws Exception { 460 System.loadLibrary("tensorflowlite_test_jni"); 461 Delegate delegate = 462 new Delegate() { 463 @Override 464 public long getNativeHandle() { 465 return getNativeHandleForInvalidDelegate(); 466 } 467 }; 468 try { 469 Interpreter interpreter = 470 new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate)); 471 fail(); 472 } catch (IllegalArgumentException e) { 473 assertThat(e).hasMessageThat().contains("Internal error: Failed to apply delegate"); 474 } 475 } 476 477 @Test testNullDelegate()478 public void testNullDelegate() throws Exception { 479 System.loadLibrary("tensorflowlite_test_jni"); 480 Delegate delegate = 481 new Delegate() { 482 @Override 483 public long getNativeHandle() { 484 return 0; 485 } 486 }; 487 try { 488 Interpreter interpreter = 489 new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate)); 490 fail(); 491 } catch (IllegalArgumentException e) { 492 assertThat(e).hasMessageThat().contains("Internal error: Invalid handle to delegate"); 493 } 494 } 495 getNativeHandleForDelegate()496 private static native long getNativeHandleForDelegate(); 497 getNativeHandleForInvalidDelegate()498 private static native long getNativeHandleForInvalidDelegate(); 499 } 500