/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
37 /** Unit tests for {@link org.tensorflow.lite.Interpreter}. */
38 @RunWith(JUnit4.class)
39 public final class InterpreterTest {
40 
41   private static final File MODEL_FILE =
42       new File("tensorflow/lite/java/src/testdata/add.bin");
43 
44   private static final File MULTIPLE_INPUTS_MODEL_FILE =
45       new File("tensorflow/lite/testdata/multi_add.bin");
46 
47   private static final File FLEX_MODEL_FILE =
48       new File("tensorflow/lite/testdata/multi_add_flex.bin");
49 
50   @Test
testInterpreter()51   public void testInterpreter() throws Exception {
52     Interpreter interpreter = new Interpreter(MODEL_FILE);
53     assertThat(interpreter).isNotNull();
54     assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
55     assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
56     assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
57     assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
58     interpreter.close();
59   }
60 
61   @Test
testInterpreterWithOptions()62   public void testInterpreterWithOptions() throws Exception {
63     Interpreter interpreter =
64         new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
65     assertThat(interpreter).isNotNull();
66     assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
67     assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
68     assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
69     assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
70     interpreter.close();
71   }
72 
73   @Test
testRunWithMappedByteBufferModel()74   public void testRunWithMappedByteBufferModel() throws Exception {
75     Path path = MODEL_FILE.toPath();
76     FileChannel fileChannel =
77         (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
78     ByteBuffer mappedByteBuffer =
79         fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
80     Interpreter interpreter = new Interpreter(mappedByteBuffer);
81     float[] oneD = {1.23f, 6.54f, 7.81f};
82     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
83     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
84     float[][][][] fourD = {threeD, threeD};
85     float[][][][] parsedOutputs = new float[2][8][8][3];
86     interpreter.run(fourD, parsedOutputs);
87     float[] outputOneD = parsedOutputs[0][0][0];
88     float[] expected = {3.69f, 19.62f, 23.43f};
89     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
90     interpreter.close();
91     fileChannel.close();
92   }
93 
94   @Test
testRunWithDirectByteBufferModel()95   public void testRunWithDirectByteBufferModel() throws Exception {
96     Path path = MODEL_FILE.toPath();
97     FileChannel fileChannel =
98         (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
99     ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) fileChannel.size());
100     byteBuffer.order(ByteOrder.nativeOrder());
101     fileChannel.read(byteBuffer);
102     Interpreter interpreter = new Interpreter(byteBuffer);
103     float[] oneD = {1.23f, 6.54f, 7.81f};
104     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
105     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
106     float[][][][] fourD = {threeD, threeD};
107     float[][][][] parsedOutputs = new float[2][8][8][3];
108     interpreter.run(fourD, parsedOutputs);
109     float[] outputOneD = parsedOutputs[0][0][0];
110     float[] expected = {3.69f, 19.62f, 23.43f};
111     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
112     interpreter.close();
113     fileChannel.close();
114   }
115 
116   @Test
testRunWithInvalidByteBufferModel()117   public void testRunWithInvalidByteBufferModel() throws Exception {
118     Path path = MODEL_FILE.toPath();
119     FileChannel fileChannel =
120         (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
121     ByteBuffer byteBuffer = ByteBuffer.allocate((int) fileChannel.size());
122     byteBuffer.order(ByteOrder.nativeOrder());
123     fileChannel.read(byteBuffer);
124     try {
125       new Interpreter(byteBuffer);
126       fail();
127     } catch (IllegalArgumentException e) {
128       assertThat(e)
129           .hasMessageThat()
130           .contains(
131               "Model ByteBuffer should be either a MappedByteBuffer"
132                   + " of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder()");
133     }
134     fileChannel.close();
135   }
136 
137   @Test
testRun()138   public void testRun() {
139     Interpreter interpreter = new Interpreter(MODEL_FILE);
140     Float[] oneD = {1.23f, 6.54f, 7.81f};
141     Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
142     Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
143     Float[][][][] fourD = {threeD, threeD};
144     Float[][][][] parsedOutputs = new Float[2][8][8][3];
145     try {
146       interpreter.run(fourD, parsedOutputs);
147       fail();
148     } catch (IllegalArgumentException e) {
149       assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;");
150     }
151     interpreter.close();
152   }
153 
154   @Test
testRunWithBoxedInputs()155   public void testRunWithBoxedInputs() {
156     Interpreter interpreter = new Interpreter(MODEL_FILE);
157     float[] oneD = {1.23f, 6.54f, 7.81f};
158     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
159     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
160     float[][][][] fourD = {threeD, threeD};
161     float[][][][] parsedOutputs = new float[2][8][8][3];
162     interpreter.run(fourD, parsedOutputs);
163     float[] outputOneD = parsedOutputs[0][0][0];
164     float[] expected = {3.69f, 19.62f, 23.43f};
165     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
166     interpreter.close();
167   }
168 
169   @Test
testRunForMultipleInputsOutputs()170   public void testRunForMultipleInputsOutputs() {
171     Interpreter interpreter = new Interpreter(MULTIPLE_INPUTS_MODEL_FILE);
172     assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
173     assertThat(interpreter.getInputTensor(0).index()).isGreaterThan(-1);
174     assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
175     assertThat(interpreter.getInputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
176     assertThat(interpreter.getInputTensor(2).dataType()).isEqualTo(DataType.FLOAT32);
177     assertThat(interpreter.getInputTensor(3).dataType()).isEqualTo(DataType.FLOAT32);
178     assertThat(interpreter.getOutputTensorCount()).isEqualTo(2);
179     assertThat(interpreter.getOutputTensor(0).index()).isGreaterThan(-1);
180     assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
181     assertThat(interpreter.getOutputTensor(1).dataType()).isEqualTo(DataType.FLOAT32);
182 
183     float[] input0 = {1.23f};
184     float[] input1 = {2.43f};
185     Object[] inputs = {input0, input1, input0, input1};
186     float[] parsedOutput0 = new float[1];
187     float[] parsedOutput1 = new float[1];
188     Map<Integer, Object> outputs = new HashMap<>();
189     outputs.put(0, parsedOutput0);
190     outputs.put(1, parsedOutput1);
191     interpreter.runForMultipleInputsOutputs(inputs, outputs);
192     float[] expected0 = {4.89f};
193     float[] expected1 = {6.09f};
194     assertThat(parsedOutput0).usingTolerance(0.1f).containsExactly(expected0).inOrder();
195     assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
196   }
197 
198   @Test
testRunWithByteBufferOutput()199   public void testRunWithByteBufferOutput() {
200     float[] oneD = {1.23f, 6.54f, 7.81f};
201     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
202     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
203     float[][][][] fourD = {threeD, threeD};
204     ByteBuffer parsedOutput =
205         ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
206     try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
207       interpreter.run(fourD, parsedOutput);
208     }
209     float[] outputOneD = {
210       parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
211     };
212     float[] expected = {3.69f, 19.62f, 23.43f};
213     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
214   }
215 
216   @Test
testResizeInput()217   public void testResizeInput() {
218     try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
219       int[] inputDims = {1};
220       interpreter.resizeInput(0, inputDims);
221       assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
222       ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
223       ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
224       interpreter.run(input, output);
225       assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
226     }
227   }
228 
229   @Test
testRunWithWrongInputType()230   public void testRunWithWrongInputType() {
231     Interpreter interpreter = new Interpreter(MODEL_FILE);
232     int[] oneD = {4, 3, 9};
233     int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
234     int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
235     int[][][][] fourD = {threeD, threeD};
236     float[][][][] parsedOutputs = new float[2][8][8][3];
237     try {
238       interpreter.run(fourD, parsedOutputs);
239       fail();
240     } catch (IllegalArgumentException e) {
241       assertThat(e)
242           .hasMessageThat()
243           .contains(
244               "Cannot convert between a TensorFlowLite tensor with type "
245                   + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
246                   + " TensorFlowLite type INT32)");
247     }
248     interpreter.close();
249   }
250 
251   @Test
testRunWithUnsupportedInputType()252   public void testRunWithUnsupportedInputType() {
253     FloatBuffer floatBuffer = FloatBuffer.allocate(10);
254     float[][][][] parsedOutputs = new float[2][8][8][3];
255     try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
256       interpreter.run(floatBuffer, parsedOutputs);
257       fail();
258     } catch (IllegalArgumentException e) {
259       assertThat(e).hasMessageThat().contains("DataType error: cannot resolve DataType of");
260     }
261   }
262 
263   @Test
testRunWithWrongOutputType()264   public void testRunWithWrongOutputType() {
265     Interpreter interpreter = new Interpreter(MODEL_FILE);
266     float[] oneD = {1.23f, 6.54f, 7.81f};
267     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
268     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
269     float[][][][] fourD = {threeD, threeD};
270     int[][][][] parsedOutputs = new int[2][8][8][3];
271     try {
272       interpreter.run(fourD, parsedOutputs);
273       fail();
274     } catch (IllegalArgumentException e) {
275       assertThat(e)
276           .hasMessageThat()
277           .contains(
278               "Cannot convert between a TensorFlowLite tensor with type "
279                   + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
280                   + " TensorFlowLite type INT32)");
281     }
282     interpreter.close();
283   }
284 
285   @Test
testGetInputIndex()286   public void testGetInputIndex() {
287     Interpreter interpreter = new Interpreter(MODEL_FILE);
288     try {
289       interpreter.getInputIndex("WrongInputName");
290       fail();
291     } catch (IllegalArgumentException e) {
292       assertThat(e)
293           .hasMessageThat()
294           .contains(
295               "'WrongInputName' is not a valid name for any input. Names of inputs and their "
296                   + "indexes are {input=0}");
297     }
298     int index = interpreter.getInputIndex("input");
299     assertThat(index).isEqualTo(0);
300   }
301 
302   @Test
testGetOutputIndex()303   public void testGetOutputIndex() {
304     Interpreter interpreter = new Interpreter(MODEL_FILE);
305     try {
306       interpreter.getOutputIndex("WrongOutputName");
307       fail();
308     } catch (IllegalArgumentException e) {
309       assertThat(e)
310           .hasMessageThat()
311           .contains(
312               "'WrongOutputName' is not a valid name for any output. Names of outputs and their"
313                   + " indexes are {output=0}");
314     }
315     int index = interpreter.getOutputIndex("output");
316     assertThat(index).isEqualTo(0);
317   }
318 
319   @Test
testTurnOnNNAPI()320   public void testTurnOnNNAPI() throws Exception {
321     Path path = MODEL_FILE.toPath();
322     FileChannel fileChannel =
323         (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
324     MappedByteBuffer mappedByteBuffer =
325         fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
326     Interpreter interpreter =
327         new Interpreter(
328             mappedByteBuffer,
329             new Interpreter.Options().setUseNNAPI(true).setAllowFp16PrecisionForFp32(true));
330     float[] oneD = {1.23f, 6.54f, 7.81f};
331     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
332     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
333     float[][][][] fourD = {threeD, threeD};
334     float[][][][] parsedOutputs = new float[2][8][8][3];
335     interpreter.run(fourD, parsedOutputs);
336     float[] outputOneD = parsedOutputs[0][0][0];
337     float[] expected = {3.69f, 19.62f, 23.43f};
338     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
339     interpreter.close();
340     fileChannel.close();
341   }
342 
343   @Test
testRedundantClose()344   public void testRedundantClose() throws Exception {
345     Interpreter interpreter = new Interpreter(MODEL_FILE);
346     interpreter.close();
347     interpreter.close();
348   }
349 
350   @Test
testNullInputs()351   public void testNullInputs() throws Exception {
352     Interpreter interpreter = new Interpreter(MODEL_FILE);
353     try {
354       interpreter.run(null, new float[2][8][8][3]);
355       fail();
356     } catch (IllegalArgumentException e) {
357       // Expected failure.
358     }
359     interpreter.close();
360   }
361 
362   @Test
testNullOutputs()363   public void testNullOutputs() throws Exception {
364     Interpreter interpreter = new Interpreter(MODEL_FILE);
365     try {
366       interpreter.run(new float[2][8][8][3], null);
367       fail();
368     } catch (IllegalArgumentException e) {
369       // Expected failure.
370     }
371     interpreter.close();
372   }
373 
374   /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
375   @Test
testFlexModel()376   public void testFlexModel() throws Exception {
377     try {
378       new Interpreter(FLEX_MODEL_FILE);
379       fail();
380     } catch (IllegalStateException e) {
381       // Expected failure.
382     }
383   }
384 
385   @Test
testDelegate()386   public void testDelegate() throws Exception {
387     System.loadLibrary("tensorflowlite_test_jni");
388     Delegate delegate =
389         new Delegate() {
390           @Override
391           public long getNativeHandle() {
392             return getNativeHandleForDelegate();
393           }
394         };
395     Interpreter interpreter =
396         new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate));
397 
398     // The native delegate stubs out the graph with a single op that produces the scalar value 7.
399     float[] oneD = {1.23f, 6.54f, 7.81f};
400     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
401     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
402     float[][][][] fourD = {threeD, threeD};
403     float[][][][] parsedOutputs = new float[2][8][8][3];
404     interpreter.run(fourD, parsedOutputs);
405     float[] outputOneD = parsedOutputs[0][0][0];
406     float[] expected = {7.0f, 7.0f, 7.0f};
407     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
408 
409     interpreter.close();
410   }
411 
412   @Test
testNullInputsAndOutputsWithDelegate()413   public void testNullInputsAndOutputsWithDelegate() throws Exception {
414     System.loadLibrary("tensorflowlite_test_jni");
415     Delegate delegate =
416         new Delegate() {
417           @Override
418           public long getNativeHandle() {
419             return getNativeHandleForDelegate();
420           }
421         };
422     Interpreter interpreter =
423         new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate));
424     // The delegate installs a custom buffer handle for all tensors, in turn allowing null to be
425     // provided for the inputs/outputs (as the client can reference the buffer directly).
426     interpreter.run(new float[2][8][8][3], null);
427     interpreter.run(null, new float[2][8][8][3]);
428     interpreter.close();
429   }
430 
431   @Test
testModifyGraphWithDelegate()432   public void testModifyGraphWithDelegate() throws Exception {
433     System.loadLibrary("tensorflowlite_test_jni");
434     Delegate delegate =
435         new Delegate() {
436           @Override
437           public long getNativeHandle() {
438             return getNativeHandleForDelegate();
439           }
440         };
441     Interpreter interpreter = new Interpreter(MODEL_FILE);
442     interpreter.modifyGraphWithDelegate(delegate);
443 
444     // The native delegate stubs out the graph with a single op that produces the scalar value 7.
445     float[] oneD = {1.23f, 6.54f, 7.81f};
446     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
447     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
448     float[][][][] fourD = {threeD, threeD};
449     float[][][][] parsedOutputs = new float[2][8][8][3];
450     interpreter.run(fourD, parsedOutputs);
451     float[] outputOneD = parsedOutputs[0][0][0];
452     float[] expected = {7.0f, 7.0f, 7.0f};
453     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
454 
455     interpreter.close();
456   }
457 
458   @Test
testInvalidDelegate()459   public void testInvalidDelegate() throws Exception {
460     System.loadLibrary("tensorflowlite_test_jni");
461     Delegate delegate =
462         new Delegate() {
463           @Override
464           public long getNativeHandle() {
465             return getNativeHandleForInvalidDelegate();
466           }
467         };
468     try {
469       Interpreter interpreter =
470           new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate));
471       fail();
472     } catch (IllegalArgumentException e) {
473       assertThat(e).hasMessageThat().contains("Internal error: Failed to apply delegate");
474     }
475   }
476 
477   @Test
testNullDelegate()478   public void testNullDelegate() throws Exception {
479     System.loadLibrary("tensorflowlite_test_jni");
480     Delegate delegate =
481         new Delegate() {
482           @Override
483           public long getNativeHandle() {
484             return 0;
485           }
486         };
487     try {
488       Interpreter interpreter =
489           new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate));
490       fail();
491     } catch (IllegalArgumentException e) {
492       assertThat(e).hasMessageThat().contains("Internal error: Invalid handle to delegate");
493     }
494   }
495 
getNativeHandleForDelegate()496   private static native long getNativeHandleForDelegate();
497 
getNativeHandleForInvalidDelegate()498   private static native long getNativeHandleForInvalidDelegate();
499 }
500