1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 import java.lang.reflect.Array; 19 import java.nio.Buffer; 20 import java.nio.ByteBuffer; 21 import java.nio.ByteOrder; 22 import java.nio.FloatBuffer; 23 import java.nio.IntBuffer; 24 import java.nio.LongBuffer; 25 import java.nio.ShortBuffer; 26 import java.util.Arrays; 27 import org.checkerframework.checker.nullness.qual.NonNull; 28 29 /** Implementation of {@link Tensor}. */ 30 // TODO(b/153882978): Add scalar getters similar to TF's Java API. 31 final class TensorImpl implements Tensor { 32 33 /** 34 * Creates a Tensor wrapper from the provided interpreter instance and tensor index. 35 * 36 * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native 37 * interpreter is valid until the tensor is closed. 38 */ fromIndex(long nativeInterpreterHandle, int tensorIndex)39 static TensorImpl fromIndex(long nativeInterpreterHandle, int tensorIndex) { 40 return new TensorImpl(create(nativeInterpreterHandle, tensorIndex, /*subgraphIndex=*/ 0)); 41 } 42 43 /** 44 * Creates a Tensor wrapper for a Signature input. 45 * 46 * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native 47 * SignatureRunner is valid until the tensor is closed. 48 */ fromSignatureInput(long signatureRunnerHandle, String inputName)49 static TensorImpl fromSignatureInput(long signatureRunnerHandle, String inputName) { 50 return new TensorImpl(createSignatureInputTensor(signatureRunnerHandle, inputName)); 51 } 52 53 /** 54 * Creates a Tensor wrapper for a Signature output. 55 * 56 * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native 57 * SignatureRunner is valid until the tensor is closed. 58 */ fromSignatureOutput(long signatureRunnerHandle, String outputName)59 static TensorImpl fromSignatureOutput(long signatureRunnerHandle, String outputName) { 60 return new TensorImpl(createSignatureOutputTensor(signatureRunnerHandle, outputName)); 61 } 62 63 /** Disposes of any resources used by the Tensor wrapper. */ close()64 void close() { 65 delete(nativeHandle); 66 nativeHandle = 0; 67 } 68 69 @Override dataType()70 public DataType dataType() { 71 return dtype; 72 } 73 74 @Override numDimensions()75 public int numDimensions() { 76 return shapeCopy.length; 77 } 78 79 @Override numBytes()80 public int numBytes() { 81 return numBytes(nativeHandle); 82 } 83 84 @Override numElements()85 public int numElements() { 86 return computeNumElements(shapeCopy); 87 } 88 89 @Override shape()90 public int[] shape() { 91 return shapeCopy; 92 } 93 94 @Override shapeSignature()95 public int[] shapeSignature() { 96 return shapeSignatureCopy; 97 } 98 99 @Override index()100 public int index() { 101 return index(nativeHandle); 102 } 103 104 @Override name()105 public String name() { 106 return name(nativeHandle); 107 } 108 109 @Override quantizationParams()110 public QuantizationParams quantizationParams() { 111 return quantizationParamsCopy; 112 } 113 114 @Override asReadOnlyBuffer()115 public ByteBuffer asReadOnlyBuffer() { 116 // Note that the ByteBuffer order is not preserved when duplicated or marked read only, so 117 // we have to repeat the call. 118 return buffer().asReadOnlyBuffer().order(ByteOrder.nativeOrder()); 119 } 120 121 /** 122 * Copies the contents of the provided {@code src} object to the Tensor. 123 * 124 * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of 125 * this tensor, a {@link ByteBuffer} of compatible primitive type with a matching flat size, or 126 * {@code null} iff the tensor has an underlying delegate buffer handle. 127 * 128 * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible 129 * with the tensor (for example, mismatched data types or shapes). 130 */ setTo(Object src)131 void setTo(Object src) { 132 if (src == null) { 133 if (hasDelegateBufferHandle(nativeHandle)) { 134 return; 135 } 136 throw new IllegalArgumentException( 137 "Null inputs are allowed only if the Tensor is bound to a buffer handle."); 138 } 139 throwIfTypeIsIncompatible(src); 140 throwIfSrcShapeIsIncompatible(src); 141 if (isBuffer(src)) { 142 setTo((Buffer) src); 143 } else if (dtype == DataType.STRING && shapeCopy.length == 0) { 144 // Update scalar string input with 1-d byte array. 145 writeScalar(nativeHandle, src); 146 } else if (src.getClass().isArray()) { 147 writeMultiDimensionalArray(nativeHandle, src); 148 } else { 149 writeScalar(nativeHandle, src); 150 } 151 } 152 setTo(Buffer src)153 private void setTo(Buffer src) { 154 // Note that we attempt to use a direct memcpy optimization for direct, native-ordered buffers. 155 // There are no base Buffer#order() or Buffer#put() methods, so again we have to ugly cast. 156 if (src instanceof ByteBuffer) { 157 ByteBuffer srcBuffer = (ByteBuffer) src; 158 if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { 159 writeDirectBuffer(nativeHandle, src); 160 } else { 161 buffer().put(srcBuffer); 162 } 163 } else if (src instanceof LongBuffer) { 164 LongBuffer srcBuffer = (LongBuffer) src; 165 if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { 166 writeDirectBuffer(nativeHandle, src); 167 } else { 168 buffer().asLongBuffer().put(srcBuffer); 169 } 170 } else if (src instanceof FloatBuffer) { 171 FloatBuffer srcBuffer = (FloatBuffer) src; 172 if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { 173 writeDirectBuffer(nativeHandle, src); 174 } else { 175 buffer().asFloatBuffer().put(srcBuffer); 176 } 177 } else if (src instanceof IntBuffer) { 178 IntBuffer srcBuffer = (IntBuffer) src; 179 if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { 180 writeDirectBuffer(nativeHandle, src); 181 } else { 182 buffer().asIntBuffer().put(srcBuffer); 183 } 184 } else if (src instanceof ShortBuffer) { 185 ShortBuffer srcBuffer = (ShortBuffer) src; 186 if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { 187 writeDirectBuffer(nativeHandle, src); 188 } else { 189 buffer().asShortBuffer().put(srcBuffer); 190 } 191 } else { 192 throw new IllegalArgumentException("Unexpected input buffer type: " + src); 193 } 194 } 195 196 /** 197 * Copies the contents of the tensor to {@code dst}. 198 * 199 * @param dst the destination buffer, either an explicitly-typed array, a compatible {@link 200 * Buffer} or {@code null} iff the tensor has an underlying delegate buffer handle. If 201 * providing a (multi-dimensional) array, its shape must match the tensor shape *exactly*. If 202 * providing a {@link Buffer}, its capacity must be at least as large as the source tensor's 203 * capacity. 204 * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example, 205 * mismatched data types or shapes). 206 */ copyTo(Object dst)207 void copyTo(Object dst) { 208 if (dst == null) { 209 if (hasDelegateBufferHandle(nativeHandle)) { 210 return; 211 } 212 throw new IllegalArgumentException( 213 "Null outputs are allowed only if the Tensor is bound to a buffer handle."); 214 } 215 throwIfTypeIsIncompatible(dst); 216 throwIfDstShapeIsIncompatible(dst); 217 if (isBuffer(dst)) { 218 copyTo((Buffer) dst); 219 } else { 220 readMultiDimensionalArray(nativeHandle, dst); 221 } 222 } 223 copyTo(Buffer dst)224 private void copyTo(Buffer dst) { 225 // There is no base Buffer#put() method, so we have to ugly cast. 226 if (dst instanceof ByteBuffer) { 227 ((ByteBuffer) dst).put(buffer()); 228 } else if (dst instanceof FloatBuffer) { 229 ((FloatBuffer) dst).put(buffer().asFloatBuffer()); 230 } else if (dst instanceof LongBuffer) { 231 ((LongBuffer) dst).put(buffer().asLongBuffer()); 232 } else if (dst instanceof IntBuffer) { 233 ((IntBuffer) dst).put(buffer().asIntBuffer()); 234 } else if (dst instanceof ShortBuffer) { 235 ((ShortBuffer) dst).put(buffer().asShortBuffer()); 236 } else { 237 throw new IllegalArgumentException("Unexpected output buffer type: " + dst); 238 } 239 } 240 241 /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */ 242 // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs. getInputShapeIfDifferent(Object input)243 int[] getInputShapeIfDifferent(Object input) { 244 if (input == null) { 245 return null; 246 } 247 // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path. 248 // The Buffer's size will be validated against this Tensor's size in {@link #setTo(Object)}. 249 if (isBuffer(input)) { 250 return null; 251 } 252 throwIfTypeIsIncompatible(input); 253 int[] inputShape = computeShapeOf(input); 254 if (Arrays.equals(shapeCopy, inputShape)) { 255 return null; 256 } 257 return inputShape; 258 } 259 260 /** 261 * Forces a refresh of the tensor's cached shape. 262 * 263 * <p>This is useful if the tensor is resized or has a dynamic shape. 264 */ refreshShape()265 void refreshShape() { 266 this.shapeCopy = shape(nativeHandle); 267 } 268 269 /** Returns the type of the data. */ dataTypeOf(@onNull Object o)270 DataType dataTypeOf(@NonNull Object o) { 271 Class<?> c = o.getClass(); 272 // For arrays, the data elements must be a *primitive* type, e.g., an 273 // array of floats is fine, but not an array of Floats. 274 if (c.isArray()) { 275 while (c.isArray()) { 276 c = c.getComponentType(); 277 } 278 if (float.class.equals(c)) { 279 return DataType.FLOAT32; 280 } else if (int.class.equals(c)) { 281 return DataType.INT32; 282 } else if (short.class.equals(c)) { 283 return DataType.INT16; 284 } else if (byte.class.equals(c)) { 285 // Byte array can be used for storing string tensors, especially for ParseExample op. 286 if (dtype == DataType.STRING) { 287 return DataType.STRING; 288 } 289 return DataType.UINT8; 290 } else if (long.class.equals(c)) { 291 return DataType.INT64; 292 } else if (boolean.class.equals(c)) { 293 return DataType.BOOL; 294 } else if (String.class.equals(c)) { 295 return DataType.STRING; 296 } 297 } else { 298 // For scalars, the type will be boxed. 299 if (Float.class.equals(c) || o instanceof FloatBuffer) { 300 return DataType.FLOAT32; 301 } else if (Integer.class.equals(c) || o instanceof IntBuffer) { 302 return DataType.INT32; 303 } else if (Short.class.equals(c) || o instanceof ShortBuffer) { 304 return DataType.INT16; 305 } else if (Byte.class.equals(c)) { 306 // Note that we don't check for ByteBuffer here; ByteBuffer payloads 307 // are allowed to map to any type, and should be handled earlier 308 // in the input/output processing pipeline. 309 return DataType.UINT8; 310 } else if (Long.class.equals(c) || o instanceof LongBuffer) { 311 return DataType.INT64; 312 } else if (Boolean.class.equals(c)) { 313 return DataType.BOOL; 314 } else if (String.class.equals(c)) { 315 return DataType.STRING; 316 } 317 } 318 throw new IllegalArgumentException( 319 "DataType error: cannot resolve DataType of " + o.getClass().getName()); 320 } 321 322 /** Returns the shape of an object as an int array. */ computeShapeOf(Object o)323 private int[] computeShapeOf(Object o) { 324 int size = computeNumDimensions(o); 325 if (dtype == DataType.STRING) { 326 Class<?> c = o.getClass(); 327 if (c.isArray()) { 328 while (c.isArray()) { 329 c = c.getComponentType(); 330 } 331 // If the given string data is stored in byte streams, the last array dimension should be 332 // treated as a value. 333 if (byte.class.equals(c)) { 334 --size; 335 } 336 } 337 } 338 int[] dimensions = new int[size]; 339 fillShape(o, 0, dimensions); 340 return dimensions; 341 } 342 343 /** Returns the number of elements in a flattened (1-D) view of the tensor's shape. */ computeNumElements(int[] shape)344 static int computeNumElements(int[] shape) { 345 int n = 1; 346 for (int j : shape) { 347 n *= j; 348 } 349 return n; 350 } 351 352 /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */ computeNumDimensions(Object o)353 static int computeNumDimensions(Object o) { 354 if (o == null || !o.getClass().isArray()) { 355 return 0; 356 } 357 if (Array.getLength(o) == 0) { 358 throw new IllegalArgumentException("Array lengths cannot be 0."); 359 } 360 return 1 + computeNumDimensions(Array.get(o, 0)); 361 } 362 363 /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */ fillShape(Object o, int dim, int[] shape)364 static void fillShape(Object o, int dim, int[] shape) { 365 if (shape == null || dim == shape.length) { 366 return; 367 } 368 final int len = Array.getLength(o); 369 if (shape[dim] == 0) { 370 shape[dim] = len; 371 } else if (shape[dim] != len) { 372 throw new IllegalArgumentException( 373 String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); 374 } 375 final int nextDim = dim + 1; 376 // Short-circuit the innermost dimension to avoid unnecessary Array.get() reflection overhead. 377 if (nextDim == shape.length) { 378 return; 379 } 380 for (int i = 0; i < len; ++i) { 381 fillShape(Array.get(o, i), nextDim, shape); 382 } 383 } 384 throwIfTypeIsIncompatible(@onNull Object o)385 private void throwIfTypeIsIncompatible(@NonNull Object o) { 386 // ByteBuffer payloads can map to any type, so exempt it from the check. 387 if (isByteBuffer(o)) { 388 return; 389 } 390 DataType oType = dataTypeOf(o); 391 392 if (oType != dtype) { 393 // INT8 and UINT8 have the same string name, "byte" 394 if (DataTypeUtils.toStringName(oType).equals(DataTypeUtils.toStringName(dtype))) { 395 return; 396 } 397 398 throw new IllegalArgumentException( 399 String.format( 400 "Cannot convert between a TensorFlowLite tensor with type %s and a Java " 401 + "object of type %s (which is compatible with the TensorFlowLite type %s).", 402 dtype, o.getClass().getName(), oType)); 403 } 404 } 405 throwIfSrcShapeIsIncompatible(Object src)406 private void throwIfSrcShapeIsIncompatible(Object src) { 407 if (isBuffer(src)) { 408 Buffer srcBuffer = (Buffer) src; 409 int bytes = numBytes(); 410 // Note that we allow the client to provide a ByteBuffer even for non-byte Tensors. 411 // In such cases, we only care that the raw byte capacity matches the tensor byte capacity. 412 int srcBytes = 413 isByteBuffer(src) ? srcBuffer.capacity() : srcBuffer.capacity() * dtype.byteSize(); 414 if (bytes != srcBytes) { 415 throw new IllegalArgumentException( 416 String.format( 417 "Cannot copy to a TensorFlowLite tensor (%s) with %d bytes from a " 418 + "Java Buffer with %d bytes.", 419 name(), bytes, srcBytes)); 420 } 421 return; 422 } 423 int[] srcShape = computeShapeOf(src); 424 if (!Arrays.equals(srcShape, shapeCopy)) { 425 throw new IllegalArgumentException( 426 String.format( 427 "Cannot copy to a TensorFlowLite tensor (%s) with shape %s from a Java object " 428 + "with shape %s.", 429 name(), Arrays.toString(shapeCopy), Arrays.toString(srcShape))); 430 } 431 } 432 throwIfDstShapeIsIncompatible(Object dst)433 private void throwIfDstShapeIsIncompatible(Object dst) { 434 if (isBuffer(dst)) { 435 Buffer dstBuffer = (Buffer) dst; 436 int bytes = numBytes(); 437 // Note that we allow the client to provide a ByteBuffer even for non-byte Tensors. 438 // In such cases, we only care that the raw byte capacity fits the tensor byte capacity. 439 // This is subtly different than Buffer *inputs*, where the size should be exact. 440 int dstBytes = 441 isByteBuffer(dst) ? dstBuffer.capacity() : dstBuffer.capacity() * dtype.byteSize(); 442 if (bytes > dstBytes) { 443 throw new IllegalArgumentException( 444 String.format( 445 "Cannot copy from a TensorFlowLite tensor (%s) with %d bytes to a " 446 + "Java Buffer with %d bytes.", 447 name(), bytes, dstBytes)); 448 } 449 return; 450 } 451 int[] dstShape = computeShapeOf(dst); 452 if (!Arrays.equals(dstShape, shapeCopy)) { 453 throw new IllegalArgumentException( 454 String.format( 455 "Cannot copy from a TensorFlowLite tensor (%s) with shape %s to a Java object " 456 + "with shape %s.", 457 name(), Arrays.toString(shapeCopy), Arrays.toString(dstShape))); 458 } 459 } 460 isBuffer(Object o)461 private static boolean isBuffer(Object o) { 462 return o instanceof Buffer; 463 } 464 isByteBuffer(Object o)465 private static boolean isByteBuffer(Object o) { 466 return o instanceof ByteBuffer; 467 } 468 469 private long nativeHandle; 470 private final DataType dtype; 471 private int[] shapeCopy; 472 private final int[] shapeSignatureCopy; 473 private final QuantizationParams quantizationParamsCopy; 474 TensorImpl(long nativeHandle)475 private TensorImpl(long nativeHandle) { 476 this.nativeHandle = nativeHandle; 477 this.dtype = DataTypeUtils.fromC(dtype(nativeHandle)); 478 this.shapeCopy = shape(nativeHandle); 479 this.shapeSignatureCopy = shapeSignature(nativeHandle); 480 this.quantizationParamsCopy = 481 new QuantizationParams( 482 quantizationScale(nativeHandle), quantizationZeroPoint(nativeHandle)); 483 } 484 buffer()485 private ByteBuffer buffer() { 486 return buffer(nativeHandle).order(ByteOrder.nativeOrder()); 487 } 488 create(long interpreterHandle, int tensorIndex, int subgraphIndex)489 private static native long create(long interpreterHandle, int tensorIndex, int subgraphIndex); 490 createSignatureInputTensor( long signatureRunnerHandle, String inputName)491 private static native long createSignatureInputTensor( 492 long signatureRunnerHandle, String inputName); 493 createSignatureOutputTensor( long signatureRunnerHandle, String outputName)494 private static native long createSignatureOutputTensor( 495 long signatureRunnerHandle, String outputName); 496 delete(long handle)497 private static native void delete(long handle); 498 buffer(long handle)499 private static native ByteBuffer buffer(long handle); 500 writeDirectBuffer(long handle, Buffer src)501 private static native void writeDirectBuffer(long handle, Buffer src); 502 dtype(long handle)503 private static native int dtype(long handle); 504 shape(long handle)505 private static native int[] shape(long handle); 506 shapeSignature(long handle)507 private static native int[] shapeSignature(long handle); 508 numBytes(long handle)509 private static native int numBytes(long handle); 510 hasDelegateBufferHandle(long handle)511 private static native boolean hasDelegateBufferHandle(long handle); 512 readMultiDimensionalArray(long handle, Object dst)513 private static native void readMultiDimensionalArray(long handle, Object dst); 514 writeMultiDimensionalArray(long handle, Object src)515 private static native void writeMultiDimensionalArray(long handle, Object src); 516 writeScalar(long handle, Object src)517 private static native void writeScalar(long handle, Object src); 518 index(long handle)519 private static native int index(long handle); 520 name(long handle)521 private static native String name(long handle); 522 quantizationScale(long handle)523 private static native float quantizationScale(long handle); 524 quantizationZeroPoint(long handle)525 private static native int quantizationZeroPoint(long handle); 526 } 527