• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 package org.pytorch.executorch;
10 
11 import com.facebook.jni.HybridData;
12 import com.facebook.jni.annotations.DoNotStrip;
13 import java.nio.Buffer;
14 import java.nio.ByteBuffer;
15 import java.nio.ByteOrder;
16 import java.nio.DoubleBuffer;
17 import java.nio.FloatBuffer;
18 import java.nio.IntBuffer;
19 import java.nio.LongBuffer;
20 import java.util.Arrays;
21 import java.util.Locale;
22 import org.pytorch.executorch.annotations.Experimental;
23 
24 /**
25  * Representation of an ExecuTorch Tensor. Behavior is similar to PyTorch's tensor objects.
26  *
27  * <p>Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data}
28  * can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided
29  * to allocate buffers properly.
30  *
31  * <p>To access Tensor data, see {@link #dtype()}, {@link #shape()}, and various {@code getDataAs*}
32  * methods.
33  *
34  * <p>When constructing {@code Tensor} objects with {@code data} as an array, it is not specified
35  * whether this data is copied or retained as a reference so it is recommended not to modify it
36  * after constructing. {@code data} passed as a {@link Buffer} is not copied, so it can be modified
37  * between {@link Module} calls to avoid reallocation. Data retrieved from {@code Tensor} objects
38  * may be copied or may be a reference to the {@code Tensor}'s internal data buffer. {@code shape}
39  * is always copied.
40  *
41  * <p>Warning: These APIs are experimental and subject to change without notice
42  */
43 @Experimental
44 public abstract class Tensor {
45   private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null";
46   private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null";
47   private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null";
48   private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative";
49   private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER =
50       "Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)";
51   private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
52       "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
53 
54   @DoNotStrip final long[] shape;
55 
56   private static final int BYTE_SIZE_BYTES = 1;
57   private static final int INT_SIZE_BYTES = 4;
58   private static final int LONG_SIZE_BYTES = 8;
59   private static final int FLOAT_SIZE_BYTES = 4;
60   private static final int DOUBLE_SIZE_BYTES = 8;
61 
62   /**
63    * Allocates a new direct {@link ByteBuffer} with native byte order with specified capacity that
64    * can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, {@link
65    * Tensor#fromBlobUnsigned(ByteBuffer, long[])}.
66    *
67    * @param numElements capacity (number of elements) of result buffer.
68    */
allocateByteBuffer(int numElements)69   public static ByteBuffer allocateByteBuffer(int numElements) {
70     return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder());
71   }
72 
73   /**
74    * Allocates a new direct {@link IntBuffer} with native byte order with specified capacity that
75    * can be used in {@link Tensor#fromBlob(IntBuffer, long[])}.
76    *
77    * @param numElements capacity (number of elements) of result buffer.
78    */
allocateIntBuffer(int numElements)79   public static IntBuffer allocateIntBuffer(int numElements) {
80     return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES)
81         .order(ByteOrder.nativeOrder())
82         .asIntBuffer();
83   }
84 
85   /**
86    * Allocates a new direct {@link FloatBuffer} with native byte order with specified capacity that
87    * can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}.
88    *
89    * @param numElements capacity (number of elements) of result buffer.
90    */
allocateFloatBuffer(int numElements)91   public static FloatBuffer allocateFloatBuffer(int numElements) {
92     return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES)
93         .order(ByteOrder.nativeOrder())
94         .asFloatBuffer();
95   }
96 
97   /**
98    * Allocates a new direct {@link LongBuffer} with native byte order with specified capacity that
99    * can be used in {@link Tensor#fromBlob(LongBuffer, long[])}.
100    *
101    * @param numElements capacity (number of elements) of result buffer.
102    */
allocateLongBuffer(int numElements)103   public static LongBuffer allocateLongBuffer(int numElements) {
104     return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES)
105         .order(ByteOrder.nativeOrder())
106         .asLongBuffer();
107   }
108 
109   /**
110    * Allocates a new direct {@link DoubleBuffer} with native byte order with specified capacity that
111    * can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
112    *
113    * @param numElements capacity (number of elements) of result buffer.
114    */
allocateDoubleBuffer(int numElements)115   public static DoubleBuffer allocateDoubleBuffer(int numElements) {
116     return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES)
117         .order(ByteOrder.nativeOrder())
118         .asDoubleBuffer();
119   }
120 
121   /**
122    * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of
123    * bytes.
124    *
125    * @param data Tensor elements
126    * @param shape Tensor shape
127    */
fromBlobUnsigned(byte[] data, long[] shape)128   public static Tensor fromBlobUnsigned(byte[] data, long[] shape) {
129     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
130     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
131     checkShape(shape);
132     checkShapeAndDataCapacityConsistency(data.length, shape);
133     final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
134     byteBuffer.put(data);
135     return new Tensor_uint8(byteBuffer, shape);
136   }
137 
138   /**
139    * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of
140    * bytes.
141    *
142    * @param data Tensor elements
143    * @param shape Tensor shape
144    */
fromBlob(byte[] data, long[] shape)145   public static Tensor fromBlob(byte[] data, long[] shape) {
146     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
147     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
148     checkShape(shape);
149     checkShapeAndDataCapacityConsistency(data.length, shape);
150     final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape));
151     byteBuffer.put(data);
152     return new Tensor_int8(byteBuffer, shape);
153   }
154 
155   /**
156    * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of
157    * ints.
158    *
159    * @param data Tensor elements
160    * @param shape Tensor shape
161    */
fromBlob(int[] data, long[] shape)162   public static Tensor fromBlob(int[] data, long[] shape) {
163     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
164     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
165     checkShape(shape);
166     checkShapeAndDataCapacityConsistency(data.length, shape);
167     final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape));
168     intBuffer.put(data);
169     return new Tensor_int32(intBuffer, shape);
170   }
171 
172   /**
173    * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array
174    * of floats.
175    *
176    * @param data Tensor elements
177    * @param shape Tensor shape
178    */
fromBlob(float[] data, long[] shape)179   public static Tensor fromBlob(float[] data, long[] shape) {
180     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
181     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
182     checkShape(shape);
183     checkShapeAndDataCapacityConsistency(data.length, shape);
184     final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape));
185     floatBuffer.put(data);
186     return new Tensor_float32(floatBuffer, shape);
187   }
188 
189   /**
190    * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
191    * longs.
192    *
193    * @param data Tensor elements
194    * @param shape Tensor shape
195    */
fromBlob(long[] data, long[] shape)196   public static Tensor fromBlob(long[] data, long[] shape) {
197     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
198     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
199     checkShape(shape);
200     checkShapeAndDataCapacityConsistency(data.length, shape);
201     final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape));
202     longBuffer.put(data);
203     return new Tensor_int64(longBuffer, shape);
204   }
205 
206   /**
207    * Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array
208    * of doubles.
209    *
210    * @param shape Tensor shape
211    * @param data Tensor elements
212    */
fromBlob(double[] data, long[] shape)213   public static Tensor fromBlob(double[] data, long[] shape) {
214     checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
215     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
216     checkShape(shape);
217     checkShapeAndDataCapacityConsistency(data.length, shape);
218     final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape));
219     doubleBuffer.put(data);
220     return new Tensor_float64(doubleBuffer, shape);
221   }
222 
223   /**
224    * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data.
225    *
226    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
227    *     elements. The buffer is used directly without copying, and changes to its content will
228    *     change the tensor.
229    * @param shape Tensor shape
230    */
fromBlobUnsigned(ByteBuffer data, long[] shape)231   public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) {
232     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
233     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
234     checkShape(shape);
235     checkShapeAndDataCapacityConsistency(data.capacity(), shape);
236     checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
237     checkArgument(
238         (data.order() == ByteOrder.nativeOrder()),
239         ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
240     return new Tensor_uint8(data, shape);
241   }
242 
243   /**
244    * Creates a new Tensor instance with dtype torch.int8 with specified shape and data.
245    *
246    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
247    *     elements. The buffer is used directly without copying, and changes to its content will
248    *     change the tensor.
249    * @param shape Tensor shape
250    */
fromBlob(ByteBuffer data, long[] shape)251   public static Tensor fromBlob(ByteBuffer data, long[] shape) {
252     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
253     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
254     checkShape(shape);
255     checkShapeAndDataCapacityConsistency(data.capacity(), shape);
256     checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
257     checkArgument(
258         (data.order() == ByteOrder.nativeOrder()),
259         ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
260     return new Tensor_int8(data, shape);
261   }
262 
263   /**
264    * Creates a new Tensor instance with dtype torch.int32 with specified shape and data.
265    *
266    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
267    *     elements. The buffer is used directly without copying, and changes to its content will
268    *     change the tensor.
269    * @param shape Tensor shape
270    */
fromBlob(IntBuffer data, long[] shape)271   public static Tensor fromBlob(IntBuffer data, long[] shape) {
272     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
273     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
274     checkShape(shape);
275     checkShapeAndDataCapacityConsistency(data.capacity(), shape);
276     checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
277     checkArgument(
278         (data.order() == ByteOrder.nativeOrder()),
279         ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
280     return new Tensor_int32(data, shape);
281   }
282 
283   /**
284    * Creates a new Tensor instance with dtype torch.float32 with specified shape and data.
285    *
286    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
287    *     elements. The buffer is used directly without copying, and changes to its content will
288    *     change the tensor.
289    * @param shape Tensor shape
290    */
fromBlob(FloatBuffer data, long[] shape)291   public static Tensor fromBlob(FloatBuffer data, long[] shape) {
292     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
293     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
294     checkShape(shape);
295     checkShapeAndDataCapacityConsistency(data.capacity(), shape);
296     checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
297     checkArgument(
298         (data.order() == ByteOrder.nativeOrder()),
299         ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
300     return new Tensor_float32(data, shape);
301   }
302 
303   /**
304    * Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
305    *
306    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
307    *     elements. The buffer is used directly without copying, and changes to its content will
308    *     change the tensor.
309    * @param shape Tensor shape
310    */
fromBlob(LongBuffer data, long[] shape)311   public static Tensor fromBlob(LongBuffer data, long[] shape) {
312     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
313     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
314     checkShape(shape);
315     checkShapeAndDataCapacityConsistency(data.capacity(), shape);
316     checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
317     checkArgument(
318         (data.order() == ByteOrder.nativeOrder()),
319         ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
320     return new Tensor_int64(data, shape);
321   }
322 
323   /**
324    * Creates a new Tensor instance with dtype torch.float64 with specified shape and data.
325    *
326    * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
327    *     elements. The buffer is used directly without copying, and changes to its content will
328    *     change the tensor.
329    * @param shape Tensor shape
330    */
fromBlob(DoubleBuffer data, long[] shape)331   public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
332     checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
333     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
334     checkShape(shape);
335     checkShapeAndDataCapacityConsistency(data.capacity(), shape);
336     checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
337     checkArgument(
338         (data.order() == ByteOrder.nativeOrder()),
339         ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
340     return new Tensor_float64(data, shape);
341   }
342 
343   @DoNotStrip private HybridData mHybridData;
344 
Tensor(long[] shape)345   private Tensor(long[] shape) {
346     checkShape(shape);
347     this.shape = Arrays.copyOf(shape, shape.length);
348   }
349 
350   /** Returns the number of elements in this tensor. */
numel()351   public long numel() {
352     return numel(this.shape);
353   }
354 
355   /** Calculates the number of elements in a tensor with the specified shape. */
numel(long[] shape)356   public static long numel(long[] shape) {
357     checkShape(shape);
358     int result = 1;
359     for (long s : shape) {
360       result *= s;
361     }
362     return result;
363   }
364 
365   /** Returns the shape of this tensor. (The array is a fresh copy.) */
shape()366   public long[] shape() {
367     return Arrays.copyOf(shape, shape.length);
368   }
369 
370   /**
371    * @return data type of this tensor.
372    */
dtype()373   public abstract DType dtype();
374 
375   // Called from native
376   @DoNotStrip
dtypeJniCode()377   int dtypeJniCode() {
378     return dtype().jniCode;
379   }
380 
381   /**
382    * @return a Java byte array that contains the tensor data. This may be a copy or reference.
383    * @throws IllegalStateException if it is called for a non-int8 tensor.
384    */
getDataAsByteArray()385   public byte[] getDataAsByteArray() {
386     throw new IllegalStateException(
387         "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
388   }
389 
390   /**
391    * @return a Java byte array that contains the tensor data. This may be a copy or reference.
392    * @throws IllegalStateException if it is called for a non-uint8 tensor.
393    */
getDataAsUnsignedByteArray()394   public byte[] getDataAsUnsignedByteArray() {
395     throw new IllegalStateException(
396         "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
397   }
398 
399   /**
400    * @return a Java int array that contains the tensor data. This may be a copy or reference.
401    * @throws IllegalStateException if it is called for a non-int32 tensor.
402    */
getDataAsIntArray()403   public int[] getDataAsIntArray() {
404     throw new IllegalStateException(
405         "Tensor of type " + getClass().getSimpleName() + " cannot return data as int array.");
406   }
407 
408   /**
409    * @return a Java float array that contains the tensor data. This may be a copy or reference.
410    * @throws IllegalStateException if it is called for a non-float32 tensor.
411    */
getDataAsFloatArray()412   public float[] getDataAsFloatArray() {
413     throw new IllegalStateException(
414         "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
415   }
416 
417   /**
418    * @return a Java long array that contains the tensor data. This may be a copy or reference.
419    * @throws IllegalStateException if it is called for a non-int64 tensor.
420    */
getDataAsLongArray()421   public long[] getDataAsLongArray() {
422     throw new IllegalStateException(
423         "Tensor of type " + getClass().getSimpleName() + " cannot return data as long array.");
424   }
425 
426   /**
427    * @return a Java double array that contains the tensor data. This may be a copy or reference.
428    * @throws IllegalStateException if it is called for a non-float64 tensor.
429    */
getDataAsDoubleArray()430   public double[] getDataAsDoubleArray() {
431     throw new IllegalStateException(
432         "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array.");
433   }
434 
435   @DoNotStrip
getRawDataBuffer()436   Buffer getRawDataBuffer() {
437     throw new IllegalStateException(
438         "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
439   }
440 
441   static class Tensor_uint8 extends Tensor {
442     private final ByteBuffer data;
443 
Tensor_uint8(ByteBuffer data, long[] shape)444     private Tensor_uint8(ByteBuffer data, long[] shape) {
445       super(shape);
446       this.data = data;
447     }
448 
449     @Override
dtype()450     public DType dtype() {
451       return DType.UINT8;
452     }
453 
454     @Override
getRawDataBuffer()455     Buffer getRawDataBuffer() {
456       return data;
457     }
458 
459     @Override
getDataAsUnsignedByteArray()460     public byte[] getDataAsUnsignedByteArray() {
461       data.rewind();
462       byte[] arr = new byte[data.remaining()];
463       data.get(arr);
464       return arr;
465     }
466 
467     @Override
toString()468     public String toString() {
469       return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
470     }
471   }
472 
473   static class Tensor_int8 extends Tensor {
474     private final ByteBuffer data;
475 
Tensor_int8(ByteBuffer data, long[] shape)476     private Tensor_int8(ByteBuffer data, long[] shape) {
477       super(shape);
478       this.data = data;
479     }
480 
481     @Override
dtype()482     public DType dtype() {
483       return DType.INT8;
484     }
485 
486     @Override
getRawDataBuffer()487     Buffer getRawDataBuffer() {
488       return data;
489     }
490 
491     @Override
getDataAsByteArray()492     public byte[] getDataAsByteArray() {
493       data.rewind();
494       byte[] arr = new byte[data.remaining()];
495       data.get(arr);
496       return arr;
497     }
498 
499     @Override
toString()500     public String toString() {
501       return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
502     }
503   }
504 
505   static class Tensor_int32 extends Tensor {
506     private final IntBuffer data;
507 
Tensor_int32(IntBuffer data, long[] shape)508     private Tensor_int32(IntBuffer data, long[] shape) {
509       super(shape);
510       this.data = data;
511     }
512 
513     @Override
dtype()514     public DType dtype() {
515       return DType.INT32;
516     }
517 
518     @Override
getRawDataBuffer()519     Buffer getRawDataBuffer() {
520       return data;
521     }
522 
523     @Override
getDataAsIntArray()524     public int[] getDataAsIntArray() {
525       data.rewind();
526       int[] arr = new int[data.remaining()];
527       data.get(arr);
528       return arr;
529     }
530 
531     @Override
toString()532     public String toString() {
533       return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
534     }
535   }
536 
537   static class Tensor_float32 extends Tensor {
538     private final FloatBuffer data;
539 
Tensor_float32(FloatBuffer data, long[] shape)540     Tensor_float32(FloatBuffer data, long[] shape) {
541       super(shape);
542       this.data = data;
543     }
544 
545     @Override
getDataAsFloatArray()546     public float[] getDataAsFloatArray() {
547       data.rewind();
548       float[] arr = new float[data.remaining()];
549       data.get(arr);
550       return arr;
551     }
552 
553     @Override
dtype()554     public DType dtype() {
555       return DType.FLOAT;
556     }
557 
558     @Override
getRawDataBuffer()559     Buffer getRawDataBuffer() {
560       return data;
561     }
562 
563     @Override
toString()564     public String toString() {
565       return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape));
566     }
567   }
568 
569   static class Tensor_int64 extends Tensor {
570     private final LongBuffer data;
571 
Tensor_int64(LongBuffer data, long[] shape)572     private Tensor_int64(LongBuffer data, long[] shape) {
573       super(shape);
574       this.data = data;
575     }
576 
577     @Override
dtype()578     public DType dtype() {
579       return DType.INT64;
580     }
581 
582     @Override
getRawDataBuffer()583     Buffer getRawDataBuffer() {
584       return data;
585     }
586 
587     @Override
getDataAsLongArray()588     public long[] getDataAsLongArray() {
589       data.rewind();
590       long[] arr = new long[data.remaining()];
591       data.get(arr);
592       return arr;
593     }
594 
595     @Override
toString()596     public String toString() {
597       return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
598     }
599   }
600 
601   static class Tensor_float64 extends Tensor {
602     private final DoubleBuffer data;
603 
Tensor_float64(DoubleBuffer data, long[] shape)604     private Tensor_float64(DoubleBuffer data, long[] shape) {
605       super(shape);
606       this.data = data;
607     }
608 
609     @Override
dtype()610     public DType dtype() {
611       return DType.DOUBLE;
612     }
613 
614     @Override
getRawDataBuffer()615     Buffer getRawDataBuffer() {
616       return data;
617     }
618 
619     @Override
getDataAsDoubleArray()620     public double[] getDataAsDoubleArray() {
621       data.rewind();
622       double[] arr = new double[data.remaining()];
623       data.get(arr);
624       return arr;
625     }
626 
627     @Override
toString()628     public String toString() {
629       return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
630     }
631   }
632 
633   // region checks
checkArgument(boolean expression, String errorMessage, Object... args)634   private static void checkArgument(boolean expression, String errorMessage, Object... args) {
635     if (!expression) {
636       throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args));
637     }
638   }
639 
checkShape(long[] shape)640   private static void checkShape(long[] shape) {
641     checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
642     for (int i = 0; i < shape.length; i++) {
643       checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE);
644     }
645   }
646 
checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape)647   private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) {
648     final long numel = numel(shape);
649     checkArgument(
650         numel == dataCapacity,
651         "Inconsistent data capacity:%d and shape number elements:%d shape:%s",
652         dataCapacity,
653         numel,
654         Arrays.toString(shape));
655   }
656 
657   // endregion checks
658 
659   // Called from native
660   @DoNotStrip
nativeNewTensor( ByteBuffer data, long[] shape, int dtype, HybridData hybridData)661   private static Tensor nativeNewTensor(
662       ByteBuffer data, long[] shape, int dtype, HybridData hybridData) {
663     Tensor tensor = null;
664 
665     if (DType.FLOAT.jniCode == dtype) {
666       tensor = new Tensor_float32(data.asFloatBuffer(), shape);
667     } else if (DType.INT32.jniCode == dtype) {
668       tensor = new Tensor_int32(data.asIntBuffer(), shape);
669     } else if (DType.INT64.jniCode == dtype) {
670       tensor = new Tensor_int64(data.asLongBuffer(), shape);
671     } else if (DType.DOUBLE.jniCode == dtype) {
672       tensor = new Tensor_float64(data.asDoubleBuffer(), shape);
673     } else if (DType.UINT8.jniCode == dtype) {
674       tensor = new Tensor_uint8(data, shape);
675     } else if (DType.INT8.jniCode == dtype) {
676       tensor = new Tensor_int8(data, shape);
677     } else {
678       throw new IllegalArgumentException("Unknown Tensor dtype");
679     }
680     tensor.mHybridData = hybridData;
681     return tensor;
682   }
683 
684   /**
685    * Serializes a {@code Tensor} into a byte array.
686    *
687    * @return The serialized byte array.
688    * @apiNote This method is experimental and subject to change without notice. This does NOT
689    *     supoprt list type.
690    */
toByteArray()691   public byte[] toByteArray() {
692     int dtypeSize = 0;
693     byte[] tensorAsByteArray = null;
694     if (dtype() == DType.UINT8) {
695       dtypeSize = BYTE_SIZE_BYTES;
696       tensorAsByteArray = new byte[(int) numel()];
697       Tensor_uint8 thiz = (Tensor_uint8) this;
698       ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray());
699     } else if (dtype() == DType.INT8) {
700       dtypeSize = BYTE_SIZE_BYTES;
701       tensorAsByteArray = new byte[(int) numel()];
702       Tensor_int8 thiz = (Tensor_int8) this;
703       ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray());
704     } else if (dtype() == DType.INT16) {
705       throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
706     } else if (dtype() == DType.INT32) {
707       dtypeSize = INT_SIZE_BYTES;
708       tensorAsByteArray = new byte[(int) numel() * dtypeSize];
709       Tensor_int32 thiz = (Tensor_int32) this;
710       ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray());
711     } else if (dtype() == DType.INT64) {
712       dtypeSize = LONG_SIZE_BYTES;
713       tensorAsByteArray = new byte[(int) numel() * dtypeSize];
714       Tensor_int64 thiz = (Tensor_int64) this;
715       ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray());
716     } else if (dtype() == DType.FLOAT) {
717       dtypeSize = FLOAT_SIZE_BYTES;
718       tensorAsByteArray = new byte[(int) numel() * dtypeSize];
719       Tensor_float32 thiz = (Tensor_float32) this;
720       ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray());
721     } else if (dtype() == DType.DOUBLE) {
722       dtypeSize = DOUBLE_SIZE_BYTES;
723       tensorAsByteArray = new byte[(int) numel() * dtypeSize];
724       Tensor_float64 thiz = (Tensor_float64) this;
725       ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray());
726     } else {
727       throw new IllegalArgumentException("Unknown Tensor dtype");
728     }
729     ByteBuffer byteBuffer =
730         ByteBuffer.allocate(1 + 1 + 4 * shape.length + dtypeSize * (int) numel());
731     byteBuffer.put((byte) dtype().jniCode);
732     byteBuffer.put((byte) shape.length);
733     for (long s : shape) {
734       byteBuffer.putInt((int) s);
735     }
736     byteBuffer.put(tensorAsByteArray);
737     return byteBuffer.array();
738   }
739 
740   /**
741    * Deserializes a {@code Tensor} from a byte[].
742    *
743    * @param buffer The byte array to deserialize from.
744    * @return The deserialized {@code Tensor}.
745    * @apiNote This method is experimental and subject to change without notice. This does NOT
746    *     supoprt list type.
747    */
fromByteArray(byte[] bytes)748   public static Tensor fromByteArray(byte[] bytes) {
749     if (bytes == null) {
750       throw new IllegalArgumentException("bytes cannot be null");
751     }
752     ByteBuffer buffer = ByteBuffer.wrap(bytes);
753     if (!buffer.hasRemaining()) {
754       throw new IllegalArgumentException("invalid buffer");
755     }
756     byte dtype = buffer.get();
757     byte shapeLength = buffer.get();
758     long[] shape = new long[(int) shapeLength];
759     long numel = 1;
760     for (int i = 0; i < shapeLength; i++) {
761       int dim = buffer.getInt();
762       if (dim < 0) {
763         throw new IllegalArgumentException("invalid shape");
764       }
765       shape[i] = dim;
766       numel *= dim;
767     }
768     if (dtype == DType.UINT8.jniCode) {
769       return new Tensor_uint8(buffer, shape);
770     } else if (dtype == DType.INT8.jniCode) {
771       return new Tensor_int8(buffer, shape);
772     } else if (dtype == DType.INT32.jniCode) {
773       return new Tensor_int32(buffer.asIntBuffer(), shape);
774     } else if (dtype == DType.INT64.jniCode) {
775       return new Tensor_int64(buffer.asLongBuffer(), shape);
776     } else if (dtype == DType.FLOAT.jniCode) {
777       return new Tensor_float32(buffer.asFloatBuffer(), shape);
778     } else if (dtype == DType.DOUBLE.jniCode) {
779       return new Tensor_float64(buffer.asDoubleBuffer(), shape);
780     } else {
781       throw new IllegalArgumentException("Unknown Tensor dtype");
782     }
783   }
784 }
785