• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.lite;
17 
18 import java.lang.reflect.Array;
19 import java.nio.ByteBuffer;
20 import java.nio.ByteOrder;
21 import java.util.Arrays;
22 
23 /**
24  * A typed multi-dimensional array used in Tensorflow Lite.
25  *
26  * <p>The native handle of a {@code Tensor} is managed by {@code NativeInterpreterWrapper}, and does
27  * not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has
28  * been closed, the tensor handle will be invalidated.
29  */
30 public final class Tensor {
31 
32   /**
33    * Creates a Tensor wrapper from the provided interpreter instance and tensor index.
34    *
35    * <p>The caller is responsible for closing the created wrapper, and ensuring the provided
36    * native interpreter is valid until the tensor is closed.
37    */
fromIndex(long nativeInterpreterHandle, int tensorIndex)38   static Tensor fromIndex(long nativeInterpreterHandle, int tensorIndex) {
39     return new Tensor(create(nativeInterpreterHandle, tensorIndex));
40   }
41 
42   /** Disposes of any resources used by the Tensor wrapper. */
close()43   void close() {
44     delete(nativeHandle);
45     nativeHandle = 0;
46   }
47 
48   /** Returns the {@link DataType} of elements stored in the Tensor. */
dataType()49   public DataType dataType() {
50     return dtype;
51   }
52 
53   /**
54    * Returns the number of dimensions (sometimes referred to as <a
55    * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor.
56    *
57    * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc.
58    */
numDimensions()59   public int numDimensions() {
60     return shapeCopy.length;
61   }
62 
63   /** Returns the size, in bytes, of the tensor data. */
numBytes()64   public int numBytes() {
65     return numBytes(nativeHandle);
66   }
67 
68   /** Returns the number of elements in a flattened (1-D) view of the tensor. */
numElements()69   public int numElements() {
70     return computeNumElements(shapeCopy);
71   }
72 
73   /**
74    * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
75    * the Tensor, i.e., the sizes of each dimension.
76    *
77    * @return an array where the i-th element is the size of the i-th dimension of the tensor.
78    */
shape()79   public int[] shape() {
80     return shapeCopy;
81   }
82 
83   /**
84    * Returns the (global) index of the tensor within the owning {@link Interpreter}.
85    *
86    * @hide
87    */
index()88   public int index() {
89     return index(nativeHandle);
90   }
91 
92   /**
93    * Copies the contents of the provided {@code src} object to the Tensor.
94    *
95    * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
96    * this tensor, a {@link ByteByffer} of compatible primitive type with a matching flat size, or
97    * {@code null} iff the tensor has an underlying delegate buffer handle.
98    *
99    * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
100    *     with the tensor (for example, mismatched data types or shapes).
101    */
setTo(Object src)102   void setTo(Object src) {
103     if (src == null) {
104       if (hasDelegateBufferHandle(nativeHandle)) {
105         return;
106       }
107       throw new IllegalArgumentException(
108           "Null inputs are allowed only if the Tensor is bound to a buffer handle.");
109     }
110     throwIfDataIsIncompatible(src);
111     if (isByteBuffer(src)) {
112       ByteBuffer srcBuffer = (ByteBuffer) src;
113       // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller
114       // retains ownership of the source buffer until inference has completed.
115       if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
116         writeDirectBuffer(nativeHandle, srcBuffer);
117       } else {
118         buffer().put(srcBuffer);
119       }
120       return;
121     }
122     writeMultiDimensionalArray(nativeHandle, src);
123   }
124 
125   /**
126    * Copies the contents of the tensor to {@code dst} and returns {@code dst}.
127    *
128    * @param dst the destination buffer, either an explicitly-typed array, a {@link ByteBuffer} or
129    *     {@code null} iff the tensor has an underlying delegate buffer handle.
130    * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
131    *     mismatched data types or shapes).
132    */
copyTo(Object dst)133   Object copyTo(Object dst) {
134     if (dst == null) {
135       if (hasDelegateBufferHandle(nativeHandle)) {
136         return dst;
137       }
138       throw new IllegalArgumentException(
139           "Null outputs are allowed only if the Tensor is bound to a buffer handle.");
140     }
141     throwIfDataIsIncompatible(dst);
142     if (dst instanceof ByteBuffer) {
143       ByteBuffer dstByteBuffer = (ByteBuffer) dst;
144       dstByteBuffer.put(buffer());
145       return dst;
146     }
147     readMultiDimensionalArray(nativeHandle, dst);
148     return dst;
149   }
150 
151   /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
152   // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
getInputShapeIfDifferent(Object input)153   int[] getInputShapeIfDifferent(Object input) {
154     if (input == null) {
155       return null;
156     }
157     // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path.
158     // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
159     if (isByteBuffer(input)) {
160       return null;
161     }
162     throwIfTypeIsIncompatible(input);
163     int[] inputShape = computeShapeOf(input);
164     if (Arrays.equals(shapeCopy, inputShape)) {
165       return null;
166     }
167     return inputShape;
168   }
169 
170   /**
171    * Forces a refresh of the tensor's cached shape.
172    *
173    * <p>This is useful if the tensor is resized or has a dynamic shape.
174    */
refreshShape()175   void refreshShape() {
176     this.shapeCopy = shape(nativeHandle);
177   }
178 
179   /** Returns the type of the data. */
dataTypeOf(Object o)180   static DataType dataTypeOf(Object o) {
181     if (o != null) {
182       Class<?> c = o.getClass();
183       while (c.isArray()) {
184         c = c.getComponentType();
185       }
186       if (float.class.equals(c)) {
187         return DataType.FLOAT32;
188       } else if (int.class.equals(c)) {
189         return DataType.INT32;
190       } else if (byte.class.equals(c)) {
191         return DataType.UINT8;
192       } else if (long.class.equals(c)) {
193         return DataType.INT64;
194       } else if (String.class.equals(c)) {
195         return DataType.STRING;
196       }
197     }
198     throw new IllegalArgumentException(
199         "DataType error: cannot resolve DataType of " + o.getClass().getName());
200   }
201 
202   /** Returns the shape of an object as an int array. */
computeShapeOf(Object o)203   static int[] computeShapeOf(Object o) {
204     int size = computeNumDimensions(o);
205     int[] dimensions = new int[size];
206     fillShape(o, 0, dimensions);
207     return dimensions;
208   }
209 
210   /** Returns the number of elements in a flattened (1-D) view of the tensor's shape. */
computeNumElements(int[] shape)211   static int computeNumElements(int[] shape) {
212     int n = 1;
213     for (int i = 0; i < shape.length; ++i) {
214       n *= shape[i];
215     }
216     return n;
217   }
218 
219   /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
computeNumDimensions(Object o)220   static int computeNumDimensions(Object o) {
221     if (o == null || !o.getClass().isArray()) {
222       return 0;
223     }
224     if (Array.getLength(o) == 0) {
225       throw new IllegalArgumentException("Array lengths cannot be 0.");
226     }
227     return 1 + computeNumDimensions(Array.get(o, 0));
228   }
229 
230   /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
fillShape(Object o, int dim, int[] shape)231   static void fillShape(Object o, int dim, int[] shape) {
232     if (shape == null || dim == shape.length) {
233       return;
234     }
235     final int len = Array.getLength(o);
236     if (shape[dim] == 0) {
237       shape[dim] = len;
238     } else if (shape[dim] != len) {
239       throw new IllegalArgumentException(
240           String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
241     }
242     for (int i = 0; i < len; ++i) {
243       fillShape(Array.get(o, i), dim + 1, shape);
244     }
245   }
246 
throwIfDataIsIncompatible(Object o)247   private void throwIfDataIsIncompatible(Object o) {
248     throwIfTypeIsIncompatible(o);
249     throwIfShapeIsIncompatible(o);
250   }
251 
throwIfTypeIsIncompatible(Object o)252   private void throwIfTypeIsIncompatible(Object o) {
253     // ByteBuffer payloads can map to any type, so exempt it from the check.
254     if (isByteBuffer(o)) {
255       return;
256     }
257     DataType oType = dataTypeOf(o);
258     if (oType != dtype) {
259       throw new IllegalArgumentException(
260           String.format(
261               "Cannot convert between a TensorFlowLite tensor with type %s and a Java "
262                   + "object of type %s (which is compatible with the TensorFlowLite type %s).",
263               dtype, o.getClass().getName(), oType));
264     }
265   }
266 
throwIfShapeIsIncompatible(Object o)267   private void throwIfShapeIsIncompatible(Object o) {
268     if (isByteBuffer(o)) {
269       ByteBuffer oBuffer = (ByteBuffer) o;
270       if (oBuffer.capacity() != numBytes()) {
271         throw new IllegalArgumentException(
272             String.format(
273                 "Cannot convert between a TensorFlowLite buffer with %d bytes and a "
274                     + "ByteBuffer with %d bytes.",
275                 numBytes(), oBuffer.capacity()));
276       }
277       return;
278     }
279     int[] oShape = computeShapeOf(o);
280     if (!Arrays.equals(oShape, shapeCopy)) {
281       throw new IllegalArgumentException(
282           String.format(
283               "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object "
284                   + "with shape %s.",
285               Arrays.toString(shapeCopy), Arrays.toString(oShape)));
286     }
287   }
288 
isByteBuffer(Object o)289   private static boolean isByteBuffer(Object o) {
290     return o instanceof ByteBuffer;
291   }
292 
293   private long nativeHandle;
294   private final DataType dtype;
295   private int[] shapeCopy;
296 
Tensor(long nativeHandle)297   private Tensor(long nativeHandle) {
298     this.nativeHandle = nativeHandle;
299     this.dtype = DataType.fromC(dtype(nativeHandle));
300     this.shapeCopy = shape(nativeHandle);
301   }
302 
buffer()303   private ByteBuffer buffer() {
304     return buffer(nativeHandle).order(ByteOrder.nativeOrder());
305   }
306 
create(long interpreterHandle, int tensorIndex)307   private static native long create(long interpreterHandle, int tensorIndex);
308 
delete(long handle)309   private static native void delete(long handle);
310 
buffer(long handle)311   private static native ByteBuffer buffer(long handle);
312 
writeDirectBuffer(long handle, ByteBuffer src)313   private static native void writeDirectBuffer(long handle, ByteBuffer src);
314 
dtype(long handle)315   private static native int dtype(long handle);
316 
shape(long handle)317   private static native int[] shape(long handle);
318 
numBytes(long handle)319   private static native int numBytes(long handle);
320 
hasDelegateBufferHandle(long handle)321   private static native boolean hasDelegateBufferHandle(long handle);
322 
readMultiDimensionalArray(long handle, Object dst)323   private static native void readMultiDimensionalArray(long handle, Object dst);
324 
writeMultiDimensionalArray(long handle, Object src)325   private static native void writeMultiDimensionalArray(long handle, Object src);
326 
index(long handle)327   private static native int index(long handle);
328 
329   static {
TensorFlowLite.init()330     TensorFlowLite.init();
331   }
332 }
333