/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.Locale;
import org.pytorch.executorch.annotations.Experimental;

/**
 * Representation of an ExecuTorch Tensor. Behavior is similar to PyTorch's tensor objects.
 *
 * <p>Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data}
 * can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided
 * to allocate buffers properly.
 *
 * <p>To access Tensor data, see {@link #dtype()}, {@link #shape()}, and the various
 * {@code getDataAs*} methods.
 *
 * <p>When constructing {@code Tensor} objects with {@code data} as an array, it is not specified
 * whether this data is copied or retained as a reference, so it is recommended not to modify it
 * after constructing the tensor. {@code data} passed as a {@link Buffer} is not copied, so it can
 * be modified between {@link Module} calls to avoid reallocation. Data retrieved from
 * {@code Tensor} objects may be copied or may be a reference to the {@code Tensor}'s internal
 * data buffer. {@code shape} is always copied.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice.
 */
@Experimental
public abstract class Tensor {
  private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
  private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
  private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
  private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative";
  private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER =
      "Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)";
  private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
      "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";

  @DoNotStrip final long[] shape;

  private static final int BYTE_SIZE_BYTES = 1;
  private static final int INT_SIZE_BYTES = 4;
  private static final int LONG_SIZE_BYTES = 8;
  private static final int FLOAT_SIZE_BYTES = 4;
  private static final int DOUBLE_SIZE_BYTES = 8;

  /**
   * Allocates a new direct {@link ByteBuffer} with native byte order and the specified capacity
   * that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])} or {@link
   * Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
   *
   * @param numElements capacity (number of elements) of result buffer.
   */
  public static ByteBuffer allocateByteBuffer(int numElements) {
    return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder());
  }
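  // Illustrative sketch (not part of the API surface): the allocate*Buffer helpers return direct,
  // native-order buffers that satisfy the requirements of the buffer-based fromBlob overloads.
  // For example, for a hypothetical 4-element uint8 tensor:
  //
  //   ByteBuffer buf = Tensor.allocateByteBuffer(4);
  //   buf.put(new byte[] {0, 1, 2, (byte) 255});
  //   Tensor t = Tensor.fromBlobUnsigned(buf, new long[] {4});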
  /**
   * Allocates a new direct {@link IntBuffer} with native byte order and the specified capacity
   * that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}.
   *
   * @param numElements capacity (number of elements) of result buffer.
   */
  public static IntBuffer allocateIntBuffer(int numElements) {
    return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES)
        .order(ByteOrder.nativeOrder())
        .asIntBuffer();
  }

  /**
   * Allocates a new direct {@link FloatBuffer} with native byte order and the specified capacity
   * that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
   *
   * @param numElements capacity (number of elements) of result buffer.
   */
  public static FloatBuffer allocateFloatBuffer(int numElements) {
    return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES)
        .order(ByteOrder.nativeOrder())
        .asFloatBuffer();
  }

  /**
   * Allocates a new direct {@link LongBuffer} with native byte order and the specified capacity
   * that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
   *
   * @param numElements capacity (number of elements) of result buffer.
   */
  public static LongBuffer allocateLongBuffer(int numElements) {
    return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES)
        .order(ByteOrder.nativeOrder())
        .asLongBuffer();
  }

  /**
   * Allocates a new direct {@link DoubleBuffer} with native byte order and the specified capacity
   * that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
   *
   * @param numElements capacity (number of elements) of result buffer.
   */
  public static DoubleBuffer allocateDoubleBuffer(int numElements) {
    return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES)
        .order(ByteOrder.nativeOrder())
        .asDoubleBuffer();
  }

  /**
   * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
   * bytes.
   *
   * @param data Tensor elements
   * @param shape Tensor shape
   */
  public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.length, shape);
    final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
    byteBuffer.put(data);
    return new Tensor_uint8(byteBuffer, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
   * bytes.
   *
   * @param data Tensor elements
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(byte[] data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.length, shape);
    final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
    byteBuffer.put(data);
    return new Tensor_int8(byteBuffer, shape);
  }
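  // Illustrative sketch: the same byte values can be interpreted as signed (torch.int8, via
  // fromBlob) or unsigned (torch.uint8, via fromBlobUnsigned). Java bytes are always signed, so
  // 0xFF reads back as -1 either way; in a uint8 tensor it represents the value 255 (mask the
  // returned byte with 0xFF when interpreting it):
  //
  //   byte[] data = new byte[] {(byte) 0xFF, 0x01};
  //   Tensor signed = Tensor.fromBlob(data, new long[] {2}); // dtype torch.int8
  //   Tensor unsigned = Tensor.fromBlobUnsigned(data, new long[] {2}); // dtype torch.uint8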
  /**
   * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
   * ints.
   *
   * @param data Tensor elements
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(int[] data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.length, shape);
    final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape));
    intBuffer.put(data);
    return new Tensor_int32(intBuffer, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
   * of floats.
   *
   * @param data Tensor elements
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(float[] data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.length, shape);
    final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape));
    floatBuffer.put(data);
    return new Tensor_float32(floatBuffer, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
   * longs.
   *
   * @param data Tensor elements
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(long[] data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.length, shape);
    final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape));
    longBuffer.put(data);
    return new Tensor_int64(longBuffer, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array
   * of doubles.
   *
   * @param data Tensor elements
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(double[] data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.length, shape);
    final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape));
    doubleBuffer.put(data);
    return new Tensor_float64(doubleBuffer, shape);
  }
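  // Illustrative sketch: for every array-based overload the array length must equal
  // Tensor.numel(shape), and the elements are copied into an internal direct buffer in array
  // order. A hypothetical 2x3 float32 tensor:
  //
  //   float[] values = new float[] {1f, 2f, 3f, 4f, 5f, 6f};
  //   Tensor t = Tensor.fromBlob(values, new long[] {2, 3});
  //   // t.dtype() == DType.FLOAT, t.numel() == 6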
  /**
   * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
   *
   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
   *     elements. The buffer is used directly without copying, and changes to its content will
   *     change the tensor.
   * @param shape Tensor shape
   */
  public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.capacity(), shape);
    checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
    checkArgument(
        (data.order() == ByteOrder.nativeOrder()),
        ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
    return new Tensor_uint8(data, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
   *
   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
   *     elements. The buffer is used directly without copying, and changes to its content will
   *     change the tensor.
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(ByteBuffer data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.capacity(), shape);
    checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
    checkArgument(
        (data.order() == ByteOrder.nativeOrder()),
        ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
    return new Tensor_int8(data, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
   *
   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
   *     elements. The buffer is used directly without copying, and changes to its content will
   *     change the tensor.
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(IntBuffer data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.capacity(), shape);
    checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
    checkArgument(
        (data.order() == ByteOrder.nativeOrder()),
        ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
    return new Tensor_int32(data, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
   *
   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
   *     elements. The buffer is used directly without copying, and changes to its content will
   *     change the tensor.
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(FloatBuffer data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.capacity(), shape);
    checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
    checkArgument(
        (data.order() == ByteOrder.nativeOrder()),
        ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
    return new Tensor_float32(data, shape);
  }
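  // Illustrative sketch: the buffer-based overloads do not copy, so a direct buffer can be
  // refilled and reused across inference calls instead of reallocating a tensor per call. The
  // Module/EValue usage below is an assumption about the surrounding runtime API ("module" and
  // "frames" are hypothetical); adjust to the actual signatures:
  //
  //   FloatBuffer input = Tensor.allocateFloatBuffer(1 * 3 * 224 * 224);
  //   Tensor inputTensor = Tensor.fromBlob(input, new long[] {1, 3, 224, 224});
  //   for (float[] frame : frames) {
  //     input.rewind();
  //     input.put(frame); // mutates the tensor's backing memory in place
  //     EValue[] out = module.forward(EValue.from(inputTensor)); // assumed Module API
  //   }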
  /**
   * Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
   *
   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
   *     elements. The buffer is used directly without copying, and changes to its content will
   *     change the tensor.
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(LongBuffer data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.capacity(), shape);
    checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
    checkArgument(
        (data.order() == ByteOrder.nativeOrder()),
        ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
    return new Tensor_int64(data, shape);
  }

  /**
   * Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
   *
   * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
   *     elements. The buffer is used directly without copying, and changes to its content will
   *     change the tensor.
   * @param shape Tensor shape
   */
  public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
    checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    checkShape(shape);
    checkShapeAndDataCapacityConsistency(data.capacity(), shape);
    checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
    checkArgument(
        (data.order() == ByteOrder.nativeOrder()),
        ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
    return new Tensor_float64(data, shape);
  }

  @DoNotStrip private HybridData mHybridData;

  private Tensor(long[] shape) {
    checkShape(shape);
    this.shape = Arrays.copyOf(shape, shape.length);
  }

  /** Returns the number of elements in this tensor. */
  public long numel() {
    return numel(this.shape);
  }

  /** Calculates the number of elements in a tensor with the specified shape. */
  public static long numel(long[] shape) {
    checkShape(shape);
    long result = 1; // accumulate in a long so that large shapes do not overflow the int range
    for (long s : shape) {
      result *= s;
    }
    return result;
  }

  /** Returns the shape of this tensor. (The array is a fresh copy.) */
  public long[] shape() {
    return Arrays.copyOf(shape, shape.length);
  }

  /**
   * @return data type of this tensor.
   */
  public abstract DType dtype();

  // Called from native
  @DoNotStrip
  int dtypeJniCode() {
    return dtype().jniCode;
  }

  /**
   * @return a Java byte array that contains the tensor data. This may be a copy or reference.
   * @throws IllegalStateException if it is called for a non-int8 tensor.
   */
  public byte[] getDataAsByteArray() {
    throw new IllegalStateException(
        "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
  }
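  // Illustrative sketch: the getDataAs*Array accessors are dtype-specific and the base-class
  // implementations throw IllegalStateException, so check dtype() first when the tensor's type is
  // not known statically:
  //
  //   if (t.dtype() == DType.FLOAT) {
  //     float[] out = t.getDataAsFloatArray();
  //   } else if (t.dtype() == DType.INT64) {
  //     long[] out = t.getDataAsLongArray();
  //   }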
  /**
   * @return a Java byte array that contains the tensor data. This may be a copy or reference.
   * @throws IllegalStateException if it is called for a non-uint8 tensor.
   */
  public byte[] getDataAsUnsignedByteArray() {
    throw new IllegalStateException(
        "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
  }

  /**
   * @return a Java int array that contains the tensor data. This may be a copy or reference.
   * @throws IllegalStateException if it is called for a non-int32 tensor.
   */
  public int[] getDataAsIntArray() {
    throw new IllegalStateException(
        "Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
  }

  /**
   * @return a Java float array that contains the tensor data. This may be a copy or reference.
   * @throws IllegalStateException if it is called for a non-float32 tensor.
   */
  public float[] getDataAsFloatArray() {
    throw new IllegalStateException(
        "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
  }

  /**
   * @return a Java long array that contains the tensor data. This may be a copy or reference.
   * @throws IllegalStateException if it is called for a non-int64 tensor.
   */
  public long[] getDataAsLongArray() {
    throw new IllegalStateException(
        "Tensor of type " + getClass().getSimpleName() + " cannot return data as long array.");
  }

  /**
   * @return a Java double array that contains the tensor data. This may be a copy or reference.
   * @throws IllegalStateException if it is called for a non-float64 tensor.
   */
  public double[] getDataAsDoubleArray() {
    throw new IllegalStateException(
        "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array.");
  }

  @DoNotStrip
  Buffer getRawDataBuffer() {
    throw new IllegalStateException(
        "Tensor of type " + getClass().getSimpleName() + " cannot return raw data buffer.");
  }

  static class Tensor_uint8 extends Tensor {
    private final ByteBuffer data;

    private Tensor_uint8(ByteBuffer data, long[] shape) {
      super(shape);
      this.data = data;
    }

    @Override
    public DType dtype() {
      return DType.UINT8;
    }

    @Override
    Buffer getRawDataBuffer() {
      return data;
    }

    @Override
    public byte[] getDataAsUnsignedByteArray() {
      data.rewind();
      byte[] arr = new byte[data.remaining()];
      data.get(arr);
      return arr;
    }

    @Override
    public String toString() {
      return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
    }
  }

  static class Tensor_int8 extends Tensor {
    private final ByteBuffer data;

    private Tensor_int8(ByteBuffer data, long[] shape) {
      super(shape);
      this.data = data;
    }

    @Override
    public DType dtype() {
      return DType.INT8;
    }

    @Override
    Buffer getRawDataBuffer() {
      return data;
    }

    @Override
    public byte[] getDataAsByteArray() {
      data.rewind();
      byte[] arr = new byte[data.remaining()];
      data.get(arr);
      return arr;
    }

    @Override
    public String toString() {
      return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
    }
  }
  static class Tensor_int32 extends Tensor {
    private final IntBuffer data;

    private Tensor_int32(IntBuffer data, long[] shape) {
      super(shape);
      this.data = data;
    }

    @Override
    public DType dtype() {
      return DType.INT32;
    }

    @Override
    Buffer getRawDataBuffer() {
      return data;
    }

    @Override
    public int[] getDataAsIntArray() {
      data.rewind();
      int[] arr = new int[data.remaining()];
      data.get(arr);
      return arr;
    }

    @Override
    public String toString() {
      return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
    }
  }

  static class Tensor_float32 extends Tensor {
    private final FloatBuffer data;

    Tensor_float32(FloatBuffer data, long[] shape) {
      super(shape);
      this.data = data;
    }

    @Override
    public float[] getDataAsFloatArray() {
      data.rewind();
      float[] arr = new float[data.remaining()];
      data.get(arr);
      return arr;
    }

    @Override
    public DType dtype() {
      return DType.FLOAT;
    }

    @Override
    Buffer getRawDataBuffer() {
      return data;
    }

    @Override
    public String toString() {
      return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape));
    }
  }

  static class Tensor_int64 extends Tensor {
    private final LongBuffer data;

    private Tensor_int64(LongBuffer data, long[] shape) {
      super(shape);
      this.data = data;
    }

    @Override
    public DType dtype() {
      return DType.INT64;
    }

    @Override
    Buffer getRawDataBuffer() {
      return data;
    }

    @Override
    public long[] getDataAsLongArray() {
      data.rewind();
      long[] arr = new long[data.remaining()];
      data.get(arr);
      return arr;
    }

    @Override
    public String toString() {
      return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
    }
  }

  static class Tensor_float64 extends Tensor {
    private final DoubleBuffer data;

    private Tensor_float64(DoubleBuffer data, long[] shape) {
      super(shape);
      this.data = data;
    }

    @Override
    public DType dtype() {
      return DType.DOUBLE;
    }

    @Override
    Buffer getRawDataBuffer() {
      return data;
    }

    @Override
    public double[] getDataAsDoubleArray() {
      data.rewind();
      double[] arr = new double[data.remaining()];
      data.get(arr);
      return arr;
    }

    @Override
    public String toString() {
      return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
    }
  }
  // region checks
  private static void checkArgument(boolean expression, String errorMessage, Object... args) {
    if (!expression) {
      throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args));
    }
  }

  private static void checkShape(long[] shape) {
    checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
    for (int i = 0; i < shape.length; i++) {
      checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE);
    }
  }

  private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) {
    final long numel = numel(shape);
    checkArgument(
        numel == dataCapacity,
        "Inconsistent data capacity:%d and shape number elements:%d shape:%s",
        dataCapacity,
        numel,
        Arrays.toString(shape));
  }

  // endregion checks

  // Called from native
  @DoNotStrip
  private static Tensor nativeNewTensor(
      ByteBuffer data, long[] shape, int dtype, HybridData hybridData) {
    Tensor tensor = null;

    if (DType.FLOAT.jniCode == dtype) {
      tensor = new Tensor_float32(data.asFloatBuffer(), shape);
    } else if (DType.INT32.jniCode == dtype) {
      tensor = new Tensor_int32(data.asIntBuffer(), shape);
    } else if (DType.INT64.jniCode == dtype) {
      tensor = new Tensor_int64(data.asLongBuffer(), shape);
    } else if (DType.DOUBLE.jniCode == dtype) {
      tensor = new Tensor_float64(data.asDoubleBuffer(), shape);
    } else if (DType.UINT8.jniCode == dtype) {
      tensor = new Tensor_uint8(data, shape);
    } else if (DType.INT8.jniCode == dtype) {
      tensor = new Tensor_int8(data, shape);
    } else {
      throw new IllegalArgumentException("Unknown Tensor dtype");
    }
    tensor.mHybridData = hybridData;
    return tensor;
  }
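  // Serialized layout produced by toByteArray() and consumed by fromByteArray(), as implemented
  // below: 1 byte dtype (DType.jniCode), 1 byte rank, then rank 4-byte dimensions, then the raw
  // element data. The java.nio buffers created by ByteBuffer.allocate/wrap default to big-endian,
  // so the serialized form is big-endian rather than native byte order.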
  /**
   * Serializes a {@code Tensor} into a byte array.
   *
   * @return The serialized byte array.
   * @apiNote This method is experimental and subject to change without notice. This does NOT
   *     support list type.
   */
  public byte[] toByteArray() {
    int dtypeSize = 0;
    byte[] tensorAsByteArray = null;
    if (dtype() == DType.UINT8) {
      dtypeSize = BYTE_SIZE_BYTES;
      tensorAsByteArray = new byte[(int) numel()];
      Tensor_uint8 thiz = (Tensor_uint8) this;
      ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray());
    } else if (dtype() == DType.INT8) {
      dtypeSize = BYTE_SIZE_BYTES;
      tensorAsByteArray = new byte[(int) numel()];
      Tensor_int8 thiz = (Tensor_int8) this;
      ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray());
    } else if (dtype() == DType.INT16) {
      throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
    } else if (dtype() == DType.INT32) {
      dtypeSize = INT_SIZE_BYTES;
      tensorAsByteArray = new byte[(int) numel() * dtypeSize];
      Tensor_int32 thiz = (Tensor_int32) this;
      ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray());
    } else if (dtype() == DType.INT64) {
      dtypeSize = LONG_SIZE_BYTES;
      tensorAsByteArray = new byte[(int) numel() * dtypeSize];
      Tensor_int64 thiz = (Tensor_int64) this;
      ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray());
    } else if (dtype() == DType.FLOAT) {
      dtypeSize = FLOAT_SIZE_BYTES;
      tensorAsByteArray = new byte[(int) numel() * dtypeSize];
      Tensor_float32 thiz = (Tensor_float32) this;
      ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray());
    } else if (dtype() == DType.DOUBLE) {
      dtypeSize = DOUBLE_SIZE_BYTES;
      tensorAsByteArray = new byte[(int) numel() * dtypeSize];
      Tensor_float64 thiz = (Tensor_float64) this;
      ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray());
    } else {
      throw new IllegalArgumentException("Unknown Tensor dtype");
    }
    ByteBuffer byteBuffer =
        ByteBuffer.allocate(1 + 1 + 4 * shape.length + dtypeSize * (int) numel());
    byteBuffer.put((byte) dtype().jniCode);
    byteBuffer.put((byte) shape.length);
    for (long s : shape) {
      byteBuffer.putInt((int) s);
    }
    byteBuffer.put(tensorAsByteArray);
    return byteBuffer.array();
  }
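  // Illustrative sketch of a serialization round trip using the layout described above:
  //
  //   Tensor original = Tensor.fromBlob(new int[] {1, 2, 3, 4}, new long[] {2, 2});
  //   byte[] bytes = original.toByteArray();
  //   Tensor restored = Tensor.fromByteArray(bytes);
  //   // restored.dtype() == DType.INT32 and restored.getDataAsIntArray() holds {1, 2, 3, 4}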
  /**
   * Deserializes a {@code Tensor} from a byte[].
   *
   * @param bytes The byte array to deserialize from.
   * @return The deserialized {@code Tensor}.
   * @apiNote This method is experimental and subject to change without notice. This does NOT
   *     support list type.
   */
  public static Tensor fromByteArray(byte[] bytes) {
    if (bytes == null) {
      throw new IllegalArgumentException("bytes cannot be null");
    }
    ByteBuffer buffer = ByteBuffer.wrap(bytes);
    if (!buffer.hasRemaining()) {
      throw new IllegalArgumentException("invalid buffer");
    }
    byte dtype = buffer.get();
    byte shapeLength = buffer.get();
    long[] shape = new long[(int) shapeLength];
    long numel = 1;
    for (int i = 0; i < shapeLength; i++) {
      int dim = buffer.getInt();
      if (dim < 0) {
        throw new IllegalArgumentException("invalid shape");
      }
      shape[i] = dim;
      numel *= dim;
    }
    if (dtype == DType.UINT8.jniCode) {
      return new Tensor_uint8(buffer, shape);
    } else if (dtype == DType.INT8.jniCode) {
      return new Tensor_int8(buffer, shape);
    } else if (dtype == DType.INT32.jniCode) {
      return new Tensor_int32(buffer.asIntBuffer(), shape);
    } else if (dtype == DType.INT64.jniCode) {
      return new Tensor_int64(buffer.asLongBuffer(), shape);
    } else if (dtype == DType.FLOAT.jniCode) {
      return new Tensor_float32(buffer.asFloatBuffer(), shape);
    } else if (dtype == DType.DOUBLE.jniCode) {
      return new Tensor_float64(buffer.asDoubleBuffer(), shape);
    } else {
      throw new IllegalArgumentException("Unknown Tensor dtype");
    }
  }
}