• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.lite;
17 
18 import java.lang.reflect.Array;
19 import java.nio.Buffer;
20 import java.nio.ByteBuffer;
21 import java.nio.ByteOrder;
22 import java.nio.FloatBuffer;
23 import java.nio.IntBuffer;
24 import java.nio.LongBuffer;
25 import java.nio.ShortBuffer;
26 import java.util.Arrays;
27 import org.checkerframework.checker.nullness.qual.NonNull;
28 
29 /** Implementation of {@link Tensor}. */
30 // TODO(b/153882978): Add scalar getters similar to TF's Java API.
31 final class TensorImpl implements Tensor {
32 
33   /**
34    * Creates a Tensor wrapper from the provided interpreter instance and tensor index.
35    *
36    * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native
37    * interpreter is valid until the tensor is closed.
38    */
fromIndex(long nativeInterpreterHandle, int tensorIndex)39   static TensorImpl fromIndex(long nativeInterpreterHandle, int tensorIndex) {
40     return new TensorImpl(create(nativeInterpreterHandle, tensorIndex, /*subgraphIndex=*/ 0));
41   }
42 
43   /**
44    * Creates a Tensor wrapper for a Signature input.
45    *
46    * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native
47    * SignatureRunner is valid until the tensor is closed.
48    */
fromSignatureInput(long signatureRunnerHandle, String inputName)49   static TensorImpl fromSignatureInput(long signatureRunnerHandle, String inputName) {
50     return new TensorImpl(createSignatureInputTensor(signatureRunnerHandle, inputName));
51   }
52 
53   /**
54    * Creates a Tensor wrapper for a Signature output.
55    *
56    * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native
57    * SignatureRunner is valid until the tensor is closed.
58    */
fromSignatureOutput(long signatureRunnerHandle, String outputName)59   static TensorImpl fromSignatureOutput(long signatureRunnerHandle, String outputName) {
60     return new TensorImpl(createSignatureOutputTensor(signatureRunnerHandle, outputName));
61   }
62 
63   /** Disposes of any resources used by the Tensor wrapper. */
close()64   void close() {
65     delete(nativeHandle);
66     nativeHandle = 0;
67   }
68 
69   @Override
dataType()70   public DataType dataType() {
71     return dtype;
72   }
73 
74   @Override
numDimensions()75   public int numDimensions() {
76     return shapeCopy.length;
77   }
78 
79   @Override
numBytes()80   public int numBytes() {
81     return numBytes(nativeHandle);
82   }
83 
84   @Override
numElements()85   public int numElements() {
86     return computeNumElements(shapeCopy);
87   }
88 
89   @Override
shape()90   public int[] shape() {
91     return shapeCopy;
92   }
93 
94   @Override
shapeSignature()95   public int[] shapeSignature() {
96     return shapeSignatureCopy;
97   }
98 
99   @Override
index()100   public int index() {
101     return index(nativeHandle);
102   }
103 
104   @Override
name()105   public String name() {
106     return name(nativeHandle);
107   }
108 
109   @Override
quantizationParams()110   public QuantizationParams quantizationParams() {
111     return quantizationParamsCopy;
112   }
113 
114   @Override
asReadOnlyBuffer()115   public ByteBuffer asReadOnlyBuffer() {
116     // Note that the ByteBuffer order is not preserved when duplicated or marked read only, so
117     // we have to repeat the call.
118     return buffer().asReadOnlyBuffer().order(ByteOrder.nativeOrder());
119   }
120 
121   /**
122    * Copies the contents of the provided {@code src} object to the Tensor.
123    *
124    * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
125    * this tensor, a {@link ByteBuffer} of compatible primitive type with a matching flat size, or
126    * {@code null} iff the tensor has an underlying delegate buffer handle.
127    *
128    * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
129    *     with the tensor (for example, mismatched data types or shapes).
130    */
setTo(Object src)131   void setTo(Object src) {
132     if (src == null) {
133       if (hasDelegateBufferHandle(nativeHandle)) {
134         return;
135       }
136       throw new IllegalArgumentException(
137           "Null inputs are allowed only if the Tensor is bound to a buffer handle.");
138     }
139     throwIfTypeIsIncompatible(src);
140     throwIfSrcShapeIsIncompatible(src);
141     if (isBuffer(src)) {
142       setTo((Buffer) src);
143     } else if (dtype == DataType.STRING && shapeCopy.length == 0) {
144       // Update scalar string input with 1-d byte array.
145       writeScalar(nativeHandle, src);
146     } else if (src.getClass().isArray()) {
147       writeMultiDimensionalArray(nativeHandle, src);
148     } else {
149       writeScalar(nativeHandle, src);
150     }
151   }
152 
setTo(Buffer src)153   private void setTo(Buffer src) {
154     // Note that we attempt to use a direct memcpy optimization for direct, native-ordered buffers.
155     // There are no base Buffer#order() or Buffer#put() methods, so again we have to ugly cast.
156     if (src instanceof ByteBuffer) {
157       ByteBuffer srcBuffer = (ByteBuffer) src;
158       if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
159         writeDirectBuffer(nativeHandle, src);
160       } else {
161         buffer().put(srcBuffer);
162       }
163     } else if (src instanceof LongBuffer) {
164       LongBuffer srcBuffer = (LongBuffer) src;
165       if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
166         writeDirectBuffer(nativeHandle, src);
167       } else {
168         buffer().asLongBuffer().put(srcBuffer);
169       }
170     } else if (src instanceof FloatBuffer) {
171       FloatBuffer srcBuffer = (FloatBuffer) src;
172       if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
173         writeDirectBuffer(nativeHandle, src);
174       } else {
175         buffer().asFloatBuffer().put(srcBuffer);
176       }
177     } else if (src instanceof IntBuffer) {
178       IntBuffer srcBuffer = (IntBuffer) src;
179       if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
180         writeDirectBuffer(nativeHandle, src);
181       } else {
182         buffer().asIntBuffer().put(srcBuffer);
183       }
184     } else if (src instanceof ShortBuffer) {
185       ShortBuffer srcBuffer = (ShortBuffer) src;
186       if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
187         writeDirectBuffer(nativeHandle, src);
188       } else {
189         buffer().asShortBuffer().put(srcBuffer);
190       }
191     } else {
192       throw new IllegalArgumentException("Unexpected input buffer type: " + src);
193     }
194   }
195 
196   /**
197    * Copies the contents of the tensor to {@code dst}.
198    *
199    * @param dst the destination buffer, either an explicitly-typed array, a compatible {@link
200    *     Buffer} or {@code null} iff the tensor has an underlying delegate buffer handle. If
201    *     providing a (multi-dimensional) array, its shape must match the tensor shape *exactly*. If
202    *     providing a {@link Buffer}, its capacity must be at least as large as the source tensor's
203    *     capacity.
204    * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
205    *     mismatched data types or shapes).
206    */
copyTo(Object dst)207   void copyTo(Object dst) {
208     if (dst == null) {
209       if (hasDelegateBufferHandle(nativeHandle)) {
210         return;
211       }
212       throw new IllegalArgumentException(
213           "Null outputs are allowed only if the Tensor is bound to a buffer handle.");
214     }
215     throwIfTypeIsIncompatible(dst);
216     throwIfDstShapeIsIncompatible(dst);
217     if (isBuffer(dst)) {
218       copyTo((Buffer) dst);
219     } else {
220       readMultiDimensionalArray(nativeHandle, dst);
221     }
222   }
223 
copyTo(Buffer dst)224   private void copyTo(Buffer dst) {
225     // There is no base Buffer#put() method, so we have to ugly cast.
226     if (dst instanceof ByteBuffer) {
227       ((ByteBuffer) dst).put(buffer());
228     } else if (dst instanceof FloatBuffer) {
229       ((FloatBuffer) dst).put(buffer().asFloatBuffer());
230     } else if (dst instanceof LongBuffer) {
231       ((LongBuffer) dst).put(buffer().asLongBuffer());
232     } else if (dst instanceof IntBuffer) {
233       ((IntBuffer) dst).put(buffer().asIntBuffer());
234     } else if (dst instanceof ShortBuffer) {
235       ((ShortBuffer) dst).put(buffer().asShortBuffer());
236     } else {
237       throw new IllegalArgumentException("Unexpected output buffer type: " + dst);
238     }
239   }
240 
241   /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
242   // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
getInputShapeIfDifferent(Object input)243   int[] getInputShapeIfDifferent(Object input) {
244     if (input == null) {
245       return null;
246     }
247     // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path.
248     // The Buffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
249     if (isBuffer(input)) {
250       return null;
251     }
252     throwIfTypeIsIncompatible(input);
253     int[] inputShape = computeShapeOf(input);
254     if (Arrays.equals(shapeCopy, inputShape)) {
255       return null;
256     }
257     return inputShape;
258   }
259 
260   /**
261    * Forces a refresh of the tensor's cached shape.
262    *
263    * <p>This is useful if the tensor is resized or has a dynamic shape.
264    */
refreshShape()265   void refreshShape() {
266     this.shapeCopy = shape(nativeHandle);
267   }
268 
269   /** Returns the type of the data. */
dataTypeOf(@onNull Object o)270   DataType dataTypeOf(@NonNull Object o) {
271     Class<?> c = o.getClass();
272     // For arrays, the data elements must be a *primitive* type, e.g., an
273     // array of floats is fine, but not an array of Floats.
274     if (c.isArray()) {
275       while (c.isArray()) {
276         c = c.getComponentType();
277       }
278       if (float.class.equals(c)) {
279         return DataType.FLOAT32;
280       } else if (int.class.equals(c)) {
281         return DataType.INT32;
282       } else if (short.class.equals(c)) {
283         return DataType.INT16;
284       } else if (byte.class.equals(c)) {
285         // Byte array can be used for storing string tensors, especially for ParseExample op.
286         if (dtype == DataType.STRING) {
287           return DataType.STRING;
288         }
289         return DataType.UINT8;
290       } else if (long.class.equals(c)) {
291         return DataType.INT64;
292       } else if (boolean.class.equals(c)) {
293         return DataType.BOOL;
294       } else if (String.class.equals(c)) {
295         return DataType.STRING;
296       }
297     } else {
298       // For scalars, the type will be boxed.
299       if (Float.class.equals(c) || o instanceof FloatBuffer) {
300         return DataType.FLOAT32;
301       } else if (Integer.class.equals(c) || o instanceof IntBuffer) {
302         return DataType.INT32;
303       } else if (Short.class.equals(c) || o instanceof ShortBuffer) {
304         return DataType.INT16;
305       } else if (Byte.class.equals(c)) {
306         // Note that we don't check for ByteBuffer here; ByteBuffer payloads
307         // are allowed to map to any type, and should be handled earlier
308         // in the input/output processing pipeline.
309         return DataType.UINT8;
310       } else if (Long.class.equals(c) || o instanceof LongBuffer) {
311         return DataType.INT64;
312       } else if (Boolean.class.equals(c)) {
313         return DataType.BOOL;
314       } else if (String.class.equals(c)) {
315         return DataType.STRING;
316       }
317     }
318     throw new IllegalArgumentException(
319         "DataType error: cannot resolve DataType of " + o.getClass().getName());
320   }
321 
322   /** Returns the shape of an object as an int array. */
computeShapeOf(Object o)323   private int[] computeShapeOf(Object o) {
324     int size = computeNumDimensions(o);
325     if (dtype == DataType.STRING) {
326       Class<?> c = o.getClass();
327       if (c.isArray()) {
328         while (c.isArray()) {
329           c = c.getComponentType();
330         }
331         // If the given string data is stored in byte streams, the last array dimension should be
332         // treated as a value.
333         if (byte.class.equals(c)) {
334           --size;
335         }
336       }
337     }
338     int[] dimensions = new int[size];
339     fillShape(o, 0, dimensions);
340     return dimensions;
341   }
342 
343   /** Returns the number of elements in a flattened (1-D) view of the tensor's shape. */
computeNumElements(int[] shape)344   static int computeNumElements(int[] shape) {
345     int n = 1;
346     for (int j : shape) {
347       n *= j;
348     }
349     return n;
350   }
351 
352   /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
computeNumDimensions(Object o)353   static int computeNumDimensions(Object o) {
354     if (o == null || !o.getClass().isArray()) {
355       return 0;
356     }
357     if (Array.getLength(o) == 0) {
358       throw new IllegalArgumentException("Array lengths cannot be 0.");
359     }
360     return 1 + computeNumDimensions(Array.get(o, 0));
361   }
362 
363   /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
fillShape(Object o, int dim, int[] shape)364   static void fillShape(Object o, int dim, int[] shape) {
365     if (shape == null || dim == shape.length) {
366       return;
367     }
368     final int len = Array.getLength(o);
369     if (shape[dim] == 0) {
370       shape[dim] = len;
371     } else if (shape[dim] != len) {
372       throw new IllegalArgumentException(
373           String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
374     }
375     final int nextDim = dim + 1;
376     // Short-circuit the innermost dimension to avoid unnecessary Array.get() reflection overhead.
377     if (nextDim == shape.length) {
378       return;
379     }
380     for (int i = 0; i < len; ++i) {
381       fillShape(Array.get(o, i), nextDim, shape);
382     }
383   }
384 
throwIfTypeIsIncompatible(@onNull Object o)385   private void throwIfTypeIsIncompatible(@NonNull Object o) {
386     // ByteBuffer payloads can map to any type, so exempt it from the check.
387     if (isByteBuffer(o)) {
388       return;
389     }
390     DataType oType = dataTypeOf(o);
391 
392     if (oType != dtype) {
393       // INT8 and UINT8 have the same string name, "byte"
394       if (DataTypeUtils.toStringName(oType).equals(DataTypeUtils.toStringName(dtype))) {
395         return;
396       }
397 
398       throw new IllegalArgumentException(
399           String.format(
400               "Cannot convert between a TensorFlowLite tensor with type %s and a Java "
401                   + "object of type %s (which is compatible with the TensorFlowLite type %s).",
402               dtype, o.getClass().getName(), oType));
403     }
404   }
405 
throwIfSrcShapeIsIncompatible(Object src)406   private void throwIfSrcShapeIsIncompatible(Object src) {
407     if (isBuffer(src)) {
408       Buffer srcBuffer = (Buffer) src;
409       int bytes = numBytes();
410       // Note that we allow the client to provide a ByteBuffer even for non-byte Tensors.
411       // In such cases, we only care that the raw byte capacity matches the tensor byte capacity.
412       int srcBytes =
413           isByteBuffer(src) ? srcBuffer.capacity() : srcBuffer.capacity() * dtype.byteSize();
414       if (bytes != srcBytes) {
415         throw new IllegalArgumentException(
416             String.format(
417                 "Cannot copy to a TensorFlowLite tensor (%s) with %d bytes from a "
418                     + "Java Buffer with %d bytes.",
419                 name(), bytes, srcBytes));
420       }
421       return;
422     }
423     int[] srcShape = computeShapeOf(src);
424     if (!Arrays.equals(srcShape, shapeCopy)) {
425       throw new IllegalArgumentException(
426           String.format(
427               "Cannot copy to a TensorFlowLite tensor (%s) with shape %s from a Java object "
428                   + "with shape %s.",
429               name(), Arrays.toString(shapeCopy), Arrays.toString(srcShape)));
430     }
431   }
432 
throwIfDstShapeIsIncompatible(Object dst)433   private void throwIfDstShapeIsIncompatible(Object dst) {
434     if (isBuffer(dst)) {
435       Buffer dstBuffer = (Buffer) dst;
436       int bytes = numBytes();
437       // Note that we allow the client to provide a ByteBuffer even for non-byte Tensors.
438       // In such cases, we only care that the raw byte capacity fits the tensor byte capacity.
439       // This is subtly different than Buffer *inputs*, where the size should be exact.
440       int dstBytes =
441           isByteBuffer(dst) ? dstBuffer.capacity() : dstBuffer.capacity() * dtype.byteSize();
442       if (bytes > dstBytes) {
443         throw new IllegalArgumentException(
444             String.format(
445                 "Cannot copy from a TensorFlowLite tensor (%s) with %d bytes to a "
446                     + "Java Buffer with %d bytes.",
447                 name(), bytes, dstBytes));
448       }
449       return;
450     }
451     int[] dstShape = computeShapeOf(dst);
452     if (!Arrays.equals(dstShape, shapeCopy)) {
453       throw new IllegalArgumentException(
454           String.format(
455               "Cannot copy from a TensorFlowLite tensor (%s) with shape %s to a Java object "
456                   + "with shape %s.",
457               name(), Arrays.toString(shapeCopy), Arrays.toString(dstShape)));
458     }
459   }
460 
isBuffer(Object o)461   private static boolean isBuffer(Object o) {
462     return o instanceof Buffer;
463   }
464 
isByteBuffer(Object o)465   private static boolean isByteBuffer(Object o) {
466     return o instanceof ByteBuffer;
467   }
468 
469   private long nativeHandle;
470   private final DataType dtype;
471   private int[] shapeCopy;
472   private final int[] shapeSignatureCopy;
473   private final QuantizationParams quantizationParamsCopy;
474 
TensorImpl(long nativeHandle)475   private TensorImpl(long nativeHandle) {
476     this.nativeHandle = nativeHandle;
477     this.dtype = DataTypeUtils.fromC(dtype(nativeHandle));
478     this.shapeCopy = shape(nativeHandle);
479     this.shapeSignatureCopy = shapeSignature(nativeHandle);
480     this.quantizationParamsCopy =
481         new QuantizationParams(
482             quantizationScale(nativeHandle), quantizationZeroPoint(nativeHandle));
483   }
484 
buffer()485   private ByteBuffer buffer() {
486     return buffer(nativeHandle).order(ByteOrder.nativeOrder());
487   }
488 
create(long interpreterHandle, int tensorIndex, int subgraphIndex)489   private static native long create(long interpreterHandle, int tensorIndex, int subgraphIndex);
490 
createSignatureInputTensor( long signatureRunnerHandle, String inputName)491   private static native long createSignatureInputTensor(
492       long signatureRunnerHandle, String inputName);
493 
createSignatureOutputTensor( long signatureRunnerHandle, String outputName)494   private static native long createSignatureOutputTensor(
495       long signatureRunnerHandle, String outputName);
496 
delete(long handle)497   private static native void delete(long handle);
498 
buffer(long handle)499   private static native ByteBuffer buffer(long handle);
500 
writeDirectBuffer(long handle, Buffer src)501   private static native void writeDirectBuffer(long handle, Buffer src);
502 
dtype(long handle)503   private static native int dtype(long handle);
504 
shape(long handle)505   private static native int[] shape(long handle);
506 
shapeSignature(long handle)507   private static native int[] shapeSignature(long handle);
508 
numBytes(long handle)509   private static native int numBytes(long handle);
510 
hasDelegateBufferHandle(long handle)511   private static native boolean hasDelegateBufferHandle(long handle);
512 
readMultiDimensionalArray(long handle, Object dst)513   private static native void readMultiDimensionalArray(long handle, Object dst);
514 
writeMultiDimensionalArray(long handle, Object src)515   private static native void writeMultiDimensionalArray(long handle, Object src);
516 
writeScalar(long handle, Object src)517   private static native void writeScalar(long handle, Object src);
518 
index(long handle)519   private static native int index(long handle);
520 
name(long handle)521   private static native String name(long handle);
522 
quantizationScale(long handle)523   private static native float quantizationScale(long handle);
524 
quantizationZeroPoint(long handle)525   private static native int quantizationZeroPoint(long handle);
526 }
527