1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 import static com.google.common.truth.Truth.assertThat; 19 import static org.junit.Assert.fail; 20 21 import java.nio.ByteBuffer; 22 import java.nio.ByteOrder; 23 import java.util.HashMap; 24 import java.util.Map; 25 import org.junit.After; 26 import org.junit.Before; 27 import org.junit.Test; 28 import org.junit.runner.RunWith; 29 import org.junit.runners.JUnit4; 30 31 /** Unit tests for {@link org.tensorflow.lite.Tensor}. */ 32 @RunWith(JUnit4.class) 33 public final class TensorTest { 34 35 private static final String MODEL_PATH = 36 "tensorflow/lite/java/src/testdata/add.bin"; 37 38 private NativeInterpreterWrapper wrapper; 39 private Tensor tensor; 40 41 @Before setUp()42 public void setUp() { 43 wrapper = new NativeInterpreterWrapper(MODEL_PATH); 44 float[] oneD = {1.23f, 6.54f, 7.81f}; 45 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 46 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 47 float[][][][] fourD = {threeD, threeD}; 48 Object[] inputs = {fourD}; 49 Map<Integer, Object> outputs = new HashMap<>(); 50 outputs.put(0, new float[2][8][8][3]); 51 wrapper.run(inputs, outputs); 52 tensor = wrapper.getOutputTensor(0); 53 assertThat(tensor.index()).isGreaterThan(-1); 54 } 55 56 @After tearDown()57 public void tearDown() { 58 wrapper.close(); 59 } 60 61 @Test testBasic()62 public void testBasic() throws Exception { 63 assertThat(tensor).isNotNull(); 64 int[] expectedShape = {2, 8, 8, 3}; 65 assertThat(tensor.shape()).isEqualTo(expectedShape); 66 assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32); 67 assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4); 68 assertThat(tensor.numElements()).isEqualTo(2 * 8 * 8 * 3); 69 assertThat(tensor.numDimensions()).isEqualTo(4); 70 } 71 72 @Test testCopyTo()73 public void testCopyTo() { 74 float[][][][] parsedOutputs = new float[2][8][8][3]; 75 tensor.copyTo(parsedOutputs); 76 float[] outputOneD = parsedOutputs[0][0][0]; 77 float[] expected = {3.69f, 19.62f, 23.43f}; 78 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 79 } 80 81 @Test testCopyToNull()82 public void testCopyToNull() { 83 try { 84 tensor.copyTo(null); 85 fail(); 86 } catch (IllegalArgumentException e) { 87 // Success. 88 } 89 } 90 91 @Test testCopyToByteBuffer()92 public void testCopyToByteBuffer() { 93 ByteBuffer parsedOutput = 94 ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); 95 tensor.copyTo(parsedOutput); 96 assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4); 97 float[] outputOneD = { 98 parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) 99 }; 100 float[] expected = {3.69f, 19.62f, 23.43f}; 101 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 102 } 103 104 @Test testCopyToInvalidByteBuffer()105 public void testCopyToInvalidByteBuffer() { 106 ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); 107 try { 108 tensor.copyTo(parsedOutput); 109 fail(); 110 } catch (IllegalArgumentException e) { 111 // Expected. 112 } 113 } 114 115 @Test testCopyToWrongType()116 public void testCopyToWrongType() { 117 int[][][][] parsedOutputs = new int[2][8][8][3]; 118 try { 119 tensor.copyTo(parsedOutputs); 120 fail(); 121 } catch (IllegalArgumentException e) { 122 assertThat(e) 123 .hasMessageThat() 124 .contains( 125 "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " 126 + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); 127 } 128 } 129 130 @Test testCopyToWrongShape()131 public void testCopyToWrongShape() { 132 float[][][][] parsedOutputs = new float[1][8][8][3]; 133 try { 134 tensor.copyTo(parsedOutputs); 135 fail(); 136 } catch (IllegalArgumentException e) { 137 assertThat(e) 138 .hasMessageThat() 139 .contains( 140 "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] " 141 + "and a Java object with shape [1, 8, 8, 3]."); 142 } 143 } 144 145 @Test testSetTo()146 public void testSetTo() { 147 float[][][][] input = new float[2][8][8][3]; 148 float[][][][] output = new float[2][8][8][3]; 149 ByteBuffer inputByteBuffer = 150 ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); 151 152 input[0][0][0][0] = 2.0f; 153 tensor.setTo(input); 154 tensor.copyTo(output); 155 assertThat(output[0][0][0][0]).isEqualTo(2.0f); 156 157 inputByteBuffer.putFloat(0, 3.0f); 158 tensor.setTo(inputByteBuffer); 159 tensor.copyTo(output); 160 assertThat(output[0][0][0][0]).isEqualTo(3.0f); 161 } 162 163 @Test testSetToNull()164 public void testSetToNull() { 165 try { 166 tensor.setTo(null); 167 fail(); 168 } catch (IllegalArgumentException e) { 169 // Success. 170 } 171 } 172 173 @Test testSetToInvalidByteBuffer()174 public void testSetToInvalidByteBuffer() { 175 ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); 176 try { 177 tensor.setTo(input); 178 fail(); 179 } catch (IllegalArgumentException e) { 180 // Success. 181 } 182 } 183 184 @Test testGetInputShapeIfDifferent()185 public void testGetInputShapeIfDifferent() { 186 ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); 187 assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull(); 188 189 float[][][][] sameShapeInput = new float[2][8][8][3]; 190 assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull(); 191 192 float[][][][] differentShapeInput = new float[1][8][8][3]; 193 assertThat(tensor.getInputShapeIfDifferent(differentShapeInput)) 194 .isEqualTo(new int[] {1, 8, 8, 3}); 195 } 196 197 @Test testDataTypeOf()198 public void testDataTypeOf() { 199 float[] testEmptyArray = {}; 200 DataType dataType = Tensor.dataTypeOf(testEmptyArray); 201 assertThat(dataType).isEqualTo(DataType.FLOAT32); 202 float[] testFloatArray = {0.783f, 0.251f}; 203 dataType = Tensor.dataTypeOf(testFloatArray); 204 assertThat(dataType).isEqualTo(DataType.FLOAT32); 205 float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; 206 dataType = Tensor.dataTypeOf(testMultiDimArray); 207 assertThat(dataType).isEqualTo(DataType.FLOAT32); 208 try { 209 double[] testDoubleArray = {0.783, 0.251}; 210 Tensor.dataTypeOf(testDoubleArray); 211 fail(); 212 } catch (IllegalArgumentException e) { 213 assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); 214 } 215 try { 216 Float[] testBoxedArray = {0.783f, 0.251f}; 217 Tensor.dataTypeOf(testBoxedArray); 218 fail(); 219 } catch (IllegalArgumentException e) { 220 assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); 221 } 222 } 223 224 @Test testNumDimensions()225 public void testNumDimensions() { 226 int scalar = 1; 227 assertThat(Tensor.computeNumDimensions(scalar)).isEqualTo(0); 228 int[][] array = {{2, 4}, {1, 9}}; 229 assertThat(Tensor.computeNumDimensions(array)).isEqualTo(2); 230 try { 231 int[] emptyArray = {}; 232 Tensor.computeNumDimensions(emptyArray); 233 fail(); 234 } catch (IllegalArgumentException e) { 235 assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); 236 } 237 } 238 239 @Test testNumElements()240 public void testNumElements() { 241 int[] scalarShape = {}; 242 assertThat(Tensor.computeNumElements(scalarShape)).isEqualTo(1); 243 int[] vectorShape = {3}; 244 assertThat(Tensor.computeNumElements(vectorShape)).isEqualTo(3); 245 int[] matrixShape = {3, 4}; 246 assertThat(Tensor.computeNumElements(matrixShape)).isEqualTo(12); 247 int[] degenerateShape = {3, 4, 0}; 248 assertThat(Tensor.computeNumElements(degenerateShape)).isEqualTo(0); 249 } 250 251 @Test testFillShape()252 public void testFillShape() { 253 int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; 254 int num = Tensor.computeNumDimensions(array); 255 int[] shape = new int[num]; 256 Tensor.fillShape(array, 0, shape); 257 assertThat(num).isEqualTo(3); 258 assertThat(shape[0]).isEqualTo(2); 259 assertThat(shape[1]).isEqualTo(3); 260 assertThat(shape[2]).isEqualTo(1); 261 } 262 263 @Test testUseAfterClose()264 public void testUseAfterClose() { 265 tensor.close(); 266 try { 267 tensor.numBytes(); 268 fail(); 269 } catch (IllegalArgumentException e) { 270 // Expected failure. 271 } 272 } 273 } 274