• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow.lite;
17 
18 import static com.google.common.truth.Truth.assertThat;
19 import static org.junit.Assert.fail;
20 
21 import java.nio.ByteBuffer;
22 import java.nio.ByteOrder;
23 import java.util.HashMap;
24 import java.util.Map;
25 import org.junit.After;
26 import org.junit.Before;
27 import org.junit.Test;
28 import org.junit.runner.RunWith;
29 import org.junit.runners.JUnit4;
30 
31 /** Unit tests for {@link org.tensorflow.lite.Tensor}. */
32 @RunWith(JUnit4.class)
33 public final class TensorTest {
34 
35   private static final String MODEL_PATH =
36       "tensorflow/lite/java/src/testdata/add.bin";
37 
38   private NativeInterpreterWrapper wrapper;
39   private Tensor tensor;
40 
41   @Before
setUp()42   public void setUp() {
43     wrapper = new NativeInterpreterWrapper(MODEL_PATH);
44     float[] oneD = {1.23f, 6.54f, 7.81f};
45     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
46     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
47     float[][][][] fourD = {threeD, threeD};
48     Object[] inputs = {fourD};
49     Map<Integer, Object> outputs = new HashMap<>();
50     outputs.put(0, new float[2][8][8][3]);
51     wrapper.run(inputs, outputs);
52     tensor = wrapper.getOutputTensor(0);
53     assertThat(tensor.index()).isGreaterThan(-1);
54   }
55 
56   @After
tearDown()57   public void tearDown() {
58     wrapper.close();
59   }
60 
61   @Test
testBasic()62   public void testBasic() throws Exception {
63     assertThat(tensor).isNotNull();
64     int[] expectedShape = {2, 8, 8, 3};
65     assertThat(tensor.shape()).isEqualTo(expectedShape);
66     assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
67     assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
68     assertThat(tensor.numElements()).isEqualTo(2 * 8 * 8 * 3);
69     assertThat(tensor.numDimensions()).isEqualTo(4);
70   }
71 
72   @Test
testCopyTo()73   public void testCopyTo() {
74     float[][][][] parsedOutputs = new float[2][8][8][3];
75     tensor.copyTo(parsedOutputs);
76     float[] outputOneD = parsedOutputs[0][0][0];
77     float[] expected = {3.69f, 19.62f, 23.43f};
78     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
79   }
80 
81   @Test
testCopyToNull()82   public void testCopyToNull() {
83     try {
84       tensor.copyTo(null);
85       fail();
86     } catch (IllegalArgumentException e) {
87       // Success.
88     }
89   }
90 
91   @Test
testCopyToByteBuffer()92   public void testCopyToByteBuffer() {
93     ByteBuffer parsedOutput =
94         ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
95     tensor.copyTo(parsedOutput);
96     assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4);
97     float[] outputOneD = {
98       parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
99     };
100     float[] expected = {3.69f, 19.62f, 23.43f};
101     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
102   }
103 
104   @Test
testCopyToInvalidByteBuffer()105   public void testCopyToInvalidByteBuffer() {
106     ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
107     try {
108       tensor.copyTo(parsedOutput);
109       fail();
110     } catch (IllegalArgumentException e) {
111       // Expected.
112     }
113   }
114 
115   @Test
testCopyToWrongType()116   public void testCopyToWrongType() {
117     int[][][][] parsedOutputs = new int[2][8][8][3];
118     try {
119       tensor.copyTo(parsedOutputs);
120       fail();
121     } catch (IllegalArgumentException e) {
122       assertThat(e)
123           .hasMessageThat()
124           .contains(
125               "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
126                   + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
127     }
128   }
129 
130   @Test
testCopyToWrongShape()131   public void testCopyToWrongShape() {
132     float[][][][] parsedOutputs = new float[1][8][8][3];
133     try {
134       tensor.copyTo(parsedOutputs);
135       fail();
136     } catch (IllegalArgumentException e) {
137       assertThat(e)
138           .hasMessageThat()
139           .contains(
140               "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] "
141                   + "and a Java object with shape [1, 8, 8, 3].");
142     }
143   }
144 
145   @Test
testSetTo()146   public void testSetTo() {
147     float[][][][] input = new float[2][8][8][3];
148     float[][][][] output = new float[2][8][8][3];
149     ByteBuffer inputByteBuffer =
150         ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
151 
152     input[0][0][0][0] = 2.0f;
153     tensor.setTo(input);
154     tensor.copyTo(output);
155     assertThat(output[0][0][0][0]).isEqualTo(2.0f);
156 
157     inputByteBuffer.putFloat(0, 3.0f);
158     tensor.setTo(inputByteBuffer);
159     tensor.copyTo(output);
160     assertThat(output[0][0][0][0]).isEqualTo(3.0f);
161   }
162 
163   @Test
testSetToNull()164   public void testSetToNull() {
165     try {
166       tensor.setTo(null);
167       fail();
168     } catch (IllegalArgumentException e) {
169       // Success.
170     }
171   }
172 
173   @Test
testSetToInvalidByteBuffer()174   public void testSetToInvalidByteBuffer() {
175     ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
176     try {
177       tensor.setTo(input);
178       fail();
179     } catch (IllegalArgumentException e) {
180       // Success.
181     }
182   }
183 
184   @Test
testGetInputShapeIfDifferent()185   public void testGetInputShapeIfDifferent() {
186     ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
187     assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull();
188 
189     float[][][][] sameShapeInput = new float[2][8][8][3];
190     assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull();
191 
192     float[][][][] differentShapeInput = new float[1][8][8][3];
193     assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
194         .isEqualTo(new int[] {1, 8, 8, 3});
195   }
196 
197   @Test
testDataTypeOf()198   public void testDataTypeOf() {
199     float[] testEmptyArray = {};
200     DataType dataType = Tensor.dataTypeOf(testEmptyArray);
201     assertThat(dataType).isEqualTo(DataType.FLOAT32);
202     float[] testFloatArray = {0.783f, 0.251f};
203     dataType = Tensor.dataTypeOf(testFloatArray);
204     assertThat(dataType).isEqualTo(DataType.FLOAT32);
205     float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
206     dataType = Tensor.dataTypeOf(testMultiDimArray);
207     assertThat(dataType).isEqualTo(DataType.FLOAT32);
208     try {
209       double[] testDoubleArray = {0.783, 0.251};
210       Tensor.dataTypeOf(testDoubleArray);
211       fail();
212     } catch (IllegalArgumentException e) {
213       assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
214     }
215     try {
216       Float[] testBoxedArray = {0.783f, 0.251f};
217       Tensor.dataTypeOf(testBoxedArray);
218       fail();
219     } catch (IllegalArgumentException e) {
220       assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
221     }
222   }
223 
224   @Test
testNumDimensions()225   public void testNumDimensions() {
226     int scalar = 1;
227     assertThat(Tensor.computeNumDimensions(scalar)).isEqualTo(0);
228     int[][] array = {{2, 4}, {1, 9}};
229     assertThat(Tensor.computeNumDimensions(array)).isEqualTo(2);
230     try {
231       int[] emptyArray = {};
232       Tensor.computeNumDimensions(emptyArray);
233       fail();
234     } catch (IllegalArgumentException e) {
235       assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
236     }
237   }
238 
239   @Test
testNumElements()240   public void testNumElements() {
241     int[] scalarShape = {};
242     assertThat(Tensor.computeNumElements(scalarShape)).isEqualTo(1);
243     int[] vectorShape = {3};
244     assertThat(Tensor.computeNumElements(vectorShape)).isEqualTo(3);
245     int[] matrixShape = {3, 4};
246     assertThat(Tensor.computeNumElements(matrixShape)).isEqualTo(12);
247     int[] degenerateShape = {3, 4, 0};
248     assertThat(Tensor.computeNumElements(degenerateShape)).isEqualTo(0);
249   }
250 
251   @Test
testFillShape()252   public void testFillShape() {
253     int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
254     int num = Tensor.computeNumDimensions(array);
255     int[] shape = new int[num];
256     Tensor.fillShape(array, 0, shape);
257     assertThat(num).isEqualTo(3);
258     assertThat(shape[0]).isEqualTo(2);
259     assertThat(shape[1]).isEqualTo(3);
260     assertThat(shape[2]).isEqualTo(1);
261   }
262 
263   @Test
testUseAfterClose()264   public void testUseAfterClose() {
265     tensor.close();
266     try {
267       tensor.numBytes();
268       fail();
269     } catch (IllegalArgumentException e) {
270       // Expected failure.
271     }
272   }
273 }
274