1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 import java.lang.reflect.Array; 19 import java.nio.ByteBuffer; 20 import java.nio.ByteOrder; 21 import java.util.Arrays; 22 23 /** 24 * A typed multi-dimensional array used in Tensorflow Lite. 25 * 26 * <p>The native handle of a {@code Tensor} is managed by {@code NativeInterpreterWrapper}, and does 27 * not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has 28 * been closed, the tensor handle will be invalidated. 29 */ 30 public final class Tensor { 31 32 /** 33 * Creates a Tensor wrapper from the provided interpreter instance and tensor index. 34 * 35 * <p>The caller is responsible for closing the created wrapper, and ensuring the provided 36 * native interpreter is valid until the tensor is closed. 37 */ fromIndex(long nativeInterpreterHandle, int tensorIndex)38 static Tensor fromIndex(long nativeInterpreterHandle, int tensorIndex) { 39 return new Tensor(create(nativeInterpreterHandle, tensorIndex)); 40 } 41 42 /** Disposes of any resources used by the Tensor wrapper. */ close()43 void close() { 44 delete(nativeHandle); 45 nativeHandle = 0; 46 } 47 48 /** Returns the {@link DataType} of elements stored in the Tensor. */ dataType()49 public DataType dataType() { 50 return dtype; 51 } 52 53 /** 54 * Returns the number of dimensions (sometimes referred to as <a 55 * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor. 56 * 57 * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc. 58 */ numDimensions()59 public int numDimensions() { 60 return shapeCopy.length; 61 } 62 63 /** Returns the size, in bytes, of the tensor data. */ numBytes()64 public int numBytes() { 65 return numBytes(nativeHandle); 66 } 67 68 /** Returns the number of elements in a flattened (1-D) view of the tensor. */ numElements()69 public int numElements() { 70 return computeNumElements(shapeCopy); 71 } 72 73 /** 74 * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of 75 * the Tensor, i.e., the sizes of each dimension. 76 * 77 * @return an array where the i-th element is the size of the i-th dimension of the tensor. 78 */ shape()79 public int[] shape() { 80 return shapeCopy; 81 } 82 83 /** 84 * Returns the (global) index of the tensor within the owning {@link Interpreter}. 85 * 86 * @hide 87 */ index()88 public int index() { 89 return index(nativeHandle); 90 } 91 92 /** 93 * Copies the contents of the provided {@code src} object to the Tensor. 94 * 95 * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of 96 * this tensor, a {@link ByteByffer} of compatible primitive type with a matching flat size, or 97 * {@code null} iff the tensor has an underlying delegate buffer handle. 98 * 99 * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible 100 * with the tensor (for example, mismatched data types or shapes). 101 */ setTo(Object src)102 void setTo(Object src) { 103 if (src == null) { 104 if (hasDelegateBufferHandle(nativeHandle)) { 105 return; 106 } 107 throw new IllegalArgumentException( 108 "Null inputs are allowed only if the Tensor is bound to a buffer handle."); 109 } 110 throwIfDataIsIncompatible(src); 111 if (isByteBuffer(src)) { 112 ByteBuffer srcBuffer = (ByteBuffer) src; 113 // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller 114 // retains ownership of the source buffer until inference has completed. 115 if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { 116 writeDirectBuffer(nativeHandle, srcBuffer); 117 } else { 118 buffer().put(srcBuffer); 119 } 120 return; 121 } 122 writeMultiDimensionalArray(nativeHandle, src); 123 } 124 125 /** 126 * Copies the contents of the tensor to {@code dst} and returns {@code dst}. 127 * 128 * @param dst the destination buffer, either an explicitly-typed array, a {@link ByteBuffer} or 129 * {@code null} iff the tensor has an underlying delegate buffer handle. 130 * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example, 131 * mismatched data types or shapes). 132 */ copyTo(Object dst)133 Object copyTo(Object dst) { 134 if (dst == null) { 135 if (hasDelegateBufferHandle(nativeHandle)) { 136 return dst; 137 } 138 throw new IllegalArgumentException( 139 "Null outputs are allowed only if the Tensor is bound to a buffer handle."); 140 } 141 throwIfDataIsIncompatible(dst); 142 if (dst instanceof ByteBuffer) { 143 ByteBuffer dstByteBuffer = (ByteBuffer) dst; 144 dstByteBuffer.put(buffer()); 145 return dst; 146 } 147 readMultiDimensionalArray(nativeHandle, dst); 148 return dst; 149 } 150 151 /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */ 152 // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs. getInputShapeIfDifferent(Object input)153 int[] getInputShapeIfDifferent(Object input) { 154 if (input == null) { 155 return null; 156 } 157 // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path. 158 // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}. 159 if (isByteBuffer(input)) { 160 return null; 161 } 162 throwIfTypeIsIncompatible(input); 163 int[] inputShape = computeShapeOf(input); 164 if (Arrays.equals(shapeCopy, inputShape)) { 165 return null; 166 } 167 return inputShape; 168 } 169 170 /** 171 * Forces a refresh of the tensor's cached shape. 172 * 173 * <p>This is useful if the tensor is resized or has a dynamic shape. 174 */ refreshShape()175 void refreshShape() { 176 this.shapeCopy = shape(nativeHandle); 177 } 178 179 /** Returns the type of the data. */ dataTypeOf(Object o)180 static DataType dataTypeOf(Object o) { 181 if (o != null) { 182 Class<?> c = o.getClass(); 183 while (c.isArray()) { 184 c = c.getComponentType(); 185 } 186 if (float.class.equals(c)) { 187 return DataType.FLOAT32; 188 } else if (int.class.equals(c)) { 189 return DataType.INT32; 190 } else if (byte.class.equals(c)) { 191 return DataType.UINT8; 192 } else if (long.class.equals(c)) { 193 return DataType.INT64; 194 } else if (String.class.equals(c)) { 195 return DataType.STRING; 196 } 197 } 198 throw new IllegalArgumentException( 199 "DataType error: cannot resolve DataType of " + o.getClass().getName()); 200 } 201 202 /** Returns the shape of an object as an int array. */ computeShapeOf(Object o)203 static int[] computeShapeOf(Object o) { 204 int size = computeNumDimensions(o); 205 int[] dimensions = new int[size]; 206 fillShape(o, 0, dimensions); 207 return dimensions; 208 } 209 210 /** Returns the number of elements in a flattened (1-D) view of the tensor's shape. */ computeNumElements(int[] shape)211 static int computeNumElements(int[] shape) { 212 int n = 1; 213 for (int i = 0; i < shape.length; ++i) { 214 n *= shape[i]; 215 } 216 return n; 217 } 218 219 /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */ computeNumDimensions(Object o)220 static int computeNumDimensions(Object o) { 221 if (o == null || !o.getClass().isArray()) { 222 return 0; 223 } 224 if (Array.getLength(o) == 0) { 225 throw new IllegalArgumentException("Array lengths cannot be 0."); 226 } 227 return 1 + computeNumDimensions(Array.get(o, 0)); 228 } 229 230 /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */ fillShape(Object o, int dim, int[] shape)231 static void fillShape(Object o, int dim, int[] shape) { 232 if (shape == null || dim == shape.length) { 233 return; 234 } 235 final int len = Array.getLength(o); 236 if (shape[dim] == 0) { 237 shape[dim] = len; 238 } else if (shape[dim] != len) { 239 throw new IllegalArgumentException( 240 String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); 241 } 242 for (int i = 0; i < len; ++i) { 243 fillShape(Array.get(o, i), dim + 1, shape); 244 } 245 } 246 throwIfDataIsIncompatible(Object o)247 private void throwIfDataIsIncompatible(Object o) { 248 throwIfTypeIsIncompatible(o); 249 throwIfShapeIsIncompatible(o); 250 } 251 throwIfTypeIsIncompatible(Object o)252 private void throwIfTypeIsIncompatible(Object o) { 253 // ByteBuffer payloads can map to any type, so exempt it from the check. 254 if (isByteBuffer(o)) { 255 return; 256 } 257 DataType oType = dataTypeOf(o); 258 if (oType != dtype) { 259 throw new IllegalArgumentException( 260 String.format( 261 "Cannot convert between a TensorFlowLite tensor with type %s and a Java " 262 + "object of type %s (which is compatible with the TensorFlowLite type %s).", 263 dtype, o.getClass().getName(), oType)); 264 } 265 } 266 throwIfShapeIsIncompatible(Object o)267 private void throwIfShapeIsIncompatible(Object o) { 268 if (isByteBuffer(o)) { 269 ByteBuffer oBuffer = (ByteBuffer) o; 270 if (oBuffer.capacity() != numBytes()) { 271 throw new IllegalArgumentException( 272 String.format( 273 "Cannot convert between a TensorFlowLite buffer with %d bytes and a " 274 + "ByteBuffer with %d bytes.", 275 numBytes(), oBuffer.capacity())); 276 } 277 return; 278 } 279 int[] oShape = computeShapeOf(o); 280 if (!Arrays.equals(oShape, shapeCopy)) { 281 throw new IllegalArgumentException( 282 String.format( 283 "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object " 284 + "with shape %s.", 285 Arrays.toString(shapeCopy), Arrays.toString(oShape))); 286 } 287 } 288 isByteBuffer(Object o)289 private static boolean isByteBuffer(Object o) { 290 return o instanceof ByteBuffer; 291 } 292 293 private long nativeHandle; 294 private final DataType dtype; 295 private int[] shapeCopy; 296 Tensor(long nativeHandle)297 private Tensor(long nativeHandle) { 298 this.nativeHandle = nativeHandle; 299 this.dtype = DataType.fromC(dtype(nativeHandle)); 300 this.shapeCopy = shape(nativeHandle); 301 } 302 buffer()303 private ByteBuffer buffer() { 304 return buffer(nativeHandle).order(ByteOrder.nativeOrder()); 305 } 306 create(long interpreterHandle, int tensorIndex)307 private static native long create(long interpreterHandle, int tensorIndex); 308 delete(long handle)309 private static native void delete(long handle); 310 buffer(long handle)311 private static native ByteBuffer buffer(long handle); 312 writeDirectBuffer(long handle, ByteBuffer src)313 private static native void writeDirectBuffer(long handle, ByteBuffer src); 314 dtype(long handle)315 private static native int dtype(long handle); 316 shape(long handle)317 private static native int[] shape(long handle); 318 numBytes(long handle)319 private static native int numBytes(long handle); 320 hasDelegateBufferHandle(long handle)321 private static native boolean hasDelegateBufferHandle(long handle); 322 readMultiDimensionalArray(long handle, Object dst)323 private static native void readMultiDimensionalArray(long handle, Object dst); 324 writeMultiDimensionalArray(long handle, Object src)325 private static native void writeMultiDimensionalArray(long handle, Object src); 326 index(long handle)327 private static native int index(long handle); 328 329 static { TensorFlowLite.init()330 TensorFlowLite.init(); 331 } 332 } 333