• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 package org.pytorch.executorch;
10 
11 import org.pytorch.executorch.annotations.Experimental;
12 
/**
 * Codes representing tensor data types.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice.
 */
18 @Experimental
public enum DType {
  // NOTE: "jniCode" must be kept in sync with scalar_type.h.
  // NOTE: Never serialize "jniCode", because it can change between releases.

  /** Code for dtype ScalarType::Byte */
  UINT8(0),
  /** Code for dtype ScalarType::Char */
  INT8(1),
  /** Code for dtype ScalarType::Short */
  INT16(2),
  /** Code for dtype ScalarType::Int */
  INT32(3),
  /** Code for dtype ScalarType::Long */
  INT64(4),
  /** Code for dtype ScalarType::Half */
  HALF(5),
  /** Code for dtype ScalarType::Float */
  FLOAT(6),
  /** Code for dtype ScalarType::Double */
  DOUBLE(7),
  /** Code for dtype ScalarType::ComplexHalf */
  COMPLEX_HALF(8),
  /** Code for dtype ScalarType::ComplexFloat */
  COMPLEX_FLOAT(9),
  /** Code for dtype ScalarType::ComplexDouble */
  COMPLEX_DOUBLE(10),
  /** Code for dtype ScalarType::Bool */
  BOOL(11),
  /** Code for dtype ScalarType::QInt8 */
  QINT8(12),
  /** Code for dtype ScalarType::QUInt8 */
  QUINT8(13),
  /** Code for dtype ScalarType::QInt32 */
  QINT32(14),
  /** Code for dtype ScalarType::BFloat16 */
  BFLOAT16(15),
  /**
   * Code for dtype ScalarType::QUInt4x2. Note that the constant name omits the "U" of the native
   * name; it is kept as-is for compatibility.
   */
  QINT4X2(16),
  /**
   * Code for dtype ScalarType::QUInt2x4. Note that the constant name omits the "U" of the native
   * name; it is kept as-is for compatibility.
   */
  QINT2X4(17),
  /** Code for dtype ScalarType::Bits1x8 */
  BITS1X8(18),
  /** Code for dtype ScalarType::Bits2x4 */
  BITS2X4(19),
  /** Code for dtype ScalarType::Bits4x2 */
  BITS4X2(20),
  /** Code for dtype ScalarType::Bits8 */
  BITS8(21),
  /** Code for dtype ScalarType::Bits16 */
  BITS16(22),
  ;

  /**
   * Reverse lookup table indexed by {@link #jniCode}. The codes above are small non-negative
   * integers, so a dense array gives constant-time lookup. Enum constants are initialized before
   * this field, so {@link #values()} is fully populated when the table is built.
   */
  private static final DType[] BY_JNI_CODE = buildJniCodeTable();

  /** Native scalar-type code for this dtype; see the NOTE at the top of this enum. */
  final int jniCode;

  DType(int jniCode) {
    this.jniCode = jniCode;
  }

  /**
   * Returns the constant whose {@link #jniCode} equals {@code jniCode}.
   *
   * <p>Intended for translating a code received from native code back into a {@code DType}. As
   * with the field itself, the code must never be persisted.
   *
   * @param jniCode native scalar-type code
   * @return the matching constant
   * @throws IllegalArgumentException if no constant has that code
   */
  static DType fromJniCode(int jniCode) {
    DType match =
        (jniCode >= 0 && jniCode < BY_JNI_CODE.length) ? BY_JNI_CODE[jniCode] : null;
    if (match == null) {
      throw new IllegalArgumentException("Unknown DType jniCode: " + jniCode);
    }
    return match;
  }

  /** Builds the code-to-constant table, failing fast if two constants share a code. */
  private static DType[] buildJniCodeTable() {
    DType[] all = values();
    int maxCode = -1;
    for (DType dtype : all) {
      maxCode = Math.max(maxCode, dtype.jniCode);
    }
    DType[] table = new DType[maxCode + 1];
    for (DType dtype : all) {
      // Codes are hand-synced with the native header; a duplicate would make lookups ambiguous.
      if (table[dtype.jniCode] != null) {
        throw new IllegalStateException(
            "Duplicate DType jniCode " + dtype.jniCode + ": " + table[dtype.jniCode] + ", " + dtype);
      }
      table[dtype.jniCode] = dtype;
    }
    return table;
  }
}
77