• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 package org.pytorch.executorch;
10 
11 import com.facebook.jni.annotations.DoNotStrip;
12 import java.nio.ByteBuffer;
13 import java.util.Arrays;
14 import java.util.Locale;
15 import org.pytorch.executorch.annotations.Experimental;
16 
17 /**
18  * Java representation of an ExecuTorch value, which is implemented as tagged union that can be one
19  * of the supported types: https://pytorch.org/docs/stable/jit.html#types .
20  *
21  * <p>Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}.
22  *
23  * <p>{@code EValue} objects are constructed with {@code EValue.from(value)}, {@code
24  * EValue.tupleFrom(value1, value2, ...)}, {@code EValue.listFrom(value1, value2, ...)}, or one of
25  * the {@code dict} methods, depending on the key type.
26  *
27  * <p>Data is retrieved from {@code EValue} objects with the {@code toX()} methods. Note that {@code
28  * str}-type EValues must be extracted with {@link #toStr()}, rather than {@link #toString()}.
29  *
30  * <p>{@code EValue} objects may retain references to objects passed into their constructors, and
31  * may return references to their internal state from {@code toX()}.
32  *
33  * <p>Warning: These APIs are experimental and subject to change without notice
34  */
35 @Experimental
36 @DoNotStrip
37 public class EValue {
38   private static final int TYPE_CODE_NONE = 0;
39 
40   private static final int TYPE_CODE_TENSOR = 1;
41   private static final int TYPE_CODE_STRING = 2;
42   private static final int TYPE_CODE_DOUBLE = 3;
43   private static final int TYPE_CODE_INT = 4;
44   private static final int TYPE_CODE_BOOL = 5;
45 
46   private String[] TYPE_NAMES = {
47     "None", "Tensor", "String", "Double", "Int", "Bool",
48   };
49 
50   @DoNotStrip private final int mTypeCode;
51   @DoNotStrip private Object mData;
52 
53   @DoNotStrip
EValue(int typeCode)54   private EValue(int typeCode) {
55     this.mTypeCode = typeCode;
56   }
57 
58   @DoNotStrip
isNone()59   public boolean isNone() {
60     return TYPE_CODE_NONE == this.mTypeCode;
61   }
62 
63   @DoNotStrip
isTensor()64   public boolean isTensor() {
65     return TYPE_CODE_TENSOR == this.mTypeCode;
66   }
67 
68   @DoNotStrip
isBool()69   public boolean isBool() {
70     return TYPE_CODE_BOOL == this.mTypeCode;
71   }
72 
73   @DoNotStrip
isInt()74   public boolean isInt() {
75     return TYPE_CODE_INT == this.mTypeCode;
76   }
77 
78   @DoNotStrip
isDouble()79   public boolean isDouble() {
80     return TYPE_CODE_DOUBLE == this.mTypeCode;
81   }
82 
83   @DoNotStrip
isString()84   public boolean isString() {
85     return TYPE_CODE_STRING == this.mTypeCode;
86   }
87 
88   /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */
89   @DoNotStrip
optionalNone()90   public static EValue optionalNone() {
91     return new EValue(TYPE_CODE_NONE);
92   }
93 
94   /** Creates a new {@code EValue} of type {@code Tensor}. */
95   @DoNotStrip
from(Tensor tensor)96   public static EValue from(Tensor tensor) {
97     final EValue iv = new EValue(TYPE_CODE_TENSOR);
98     iv.mData = tensor;
99     return iv;
100   }
101 
102   /** Creates a new {@code EValue} of type {@code bool}. */
103   @DoNotStrip
from(boolean value)104   public static EValue from(boolean value) {
105     final EValue iv = new EValue(TYPE_CODE_BOOL);
106     iv.mData = value;
107     return iv;
108   }
109 
110   /** Creates a new {@code EValue} of type {@code int}. */
111   @DoNotStrip
from(long value)112   public static EValue from(long value) {
113     final EValue iv = new EValue(TYPE_CODE_INT);
114     iv.mData = value;
115     return iv;
116   }
117 
118   /** Creates a new {@code EValue} of type {@code double}. */
119   @DoNotStrip
from(double value)120   public static EValue from(double value) {
121     final EValue iv = new EValue(TYPE_CODE_DOUBLE);
122     iv.mData = value;
123     return iv;
124   }
125 
126   /** Creates a new {@code EValue} of type {@code str}. */
127   @DoNotStrip
from(String value)128   public static EValue from(String value) {
129     final EValue iv = new EValue(TYPE_CODE_STRING);
130     iv.mData = value;
131     return iv;
132   }
133 
134   @DoNotStrip
toTensor()135   public Tensor toTensor() {
136     preconditionType(TYPE_CODE_TENSOR, mTypeCode);
137     return (Tensor) mData;
138   }
139 
140   @DoNotStrip
toBool()141   public boolean toBool() {
142     preconditionType(TYPE_CODE_BOOL, mTypeCode);
143     return (boolean) mData;
144   }
145 
146   @DoNotStrip
toInt()147   public long toInt() {
148     preconditionType(TYPE_CODE_INT, mTypeCode);
149     return (long) mData;
150   }
151 
152   @DoNotStrip
toDouble()153   public double toDouble() {
154     preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
155     return (double) mData;
156   }
157 
158   @DoNotStrip
toStr()159   public String toStr() {
160     preconditionType(TYPE_CODE_STRING, mTypeCode);
161     return (String) mData;
162   }
163 
preconditionType(int typeCodeExpected, int typeCode)164   private void preconditionType(int typeCodeExpected, int typeCode) {
165     if (typeCode != typeCodeExpected) {
166       throw new IllegalStateException(
167           String.format(
168               Locale.US,
169               "Expected EValue type %s, actual type %s",
170               getTypeName(typeCodeExpected),
171               getTypeName(typeCode)));
172     }
173   }
174 
getTypeName(int typeCode)175   private String getTypeName(int typeCode) {
176     return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown";
177   }
178 
179   /**
180    * Serializes an {@code EValue} into a byte array.
181    *
182    * @return The serialized byte array.
183    * @apiNote This method is experimental and subject to change without notice.
184    */
toByteArray()185   public byte[] toByteArray() {
186     if (isNone()) {
187       return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array();
188     } else if (isTensor()) {
189       Tensor t = toTensor();
190       byte[] tByteArray = t.toByteArray();
191       return ByteBuffer.allocate(1 + tByteArray.length)
192           .put((byte) TYPE_CODE_TENSOR)
193           .put(tByteArray)
194           .array();
195     } else if (isBool()) {
196       return ByteBuffer.allocate(2)
197           .put((byte) TYPE_CODE_BOOL)
198           .put((byte) (toBool() ? 1 : 0))
199           .array();
200     } else if (isInt()) {
201       return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array();
202     } else if (isDouble()) {
203       return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array();
204     } else if (isString()) {
205       return ByteBuffer.allocate(1 + toString().length())
206           .put((byte) TYPE_CODE_STRING)
207           .put(toString().getBytes())
208           .array();
209     } else {
210       throw new IllegalArgumentException("Unknown Tensor dtype");
211     }
212   }
213 
214   /**
215    * Deserializes an {@code EValue} from a byte[].
216    *
217    * @param bytes The byte array to deserialize from.
218    * @return The deserialized {@code EValue}.
219    * @apiNote This method is experimental and subject to change without notice.
220    */
fromByteArray(byte[] bytes)221   public static EValue fromByteArray(byte[] bytes) {
222     ByteBuffer buffer = ByteBuffer.wrap(bytes);
223     if (buffer == null) {
224       throw new IllegalArgumentException("buffer cannot be null");
225     }
226     if (!buffer.hasRemaining()) {
227       throw new IllegalArgumentException("invalid buffer");
228     }
229     int typeCode = buffer.get();
230     switch (typeCode) {
231       case TYPE_CODE_NONE:
232         return new EValue(TYPE_CODE_NONE);
233       case TYPE_CODE_TENSOR:
234         byte[] bufferArray = buffer.array();
235         return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length)));
236       case TYPE_CODE_STRING:
237         throw new IllegalArgumentException("TYPE_CODE_STRING is not supported");
238       case TYPE_CODE_DOUBLE:
239         return from(buffer.getDouble());
240       case TYPE_CODE_INT:
241         return from(buffer.getLong());
242       case TYPE_CODE_BOOL:
243         return from(buffer.get() != 0);
244     }
245     throw new IllegalArgumentException("invalid type code: " + typeCode);
246   }
247 }
248