1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 package org.pytorch.executorch; 10 11 import com.facebook.jni.annotations.DoNotStrip; 12 import java.nio.ByteBuffer; 13 import java.util.Arrays; 14 import java.util.Locale; 15 import org.pytorch.executorch.annotations.Experimental; 16 17 /** 18 * Java representation of an ExecuTorch value, which is implemented as tagged union that can be one 19 * of the supported types: https://pytorch.org/docs/stable/jit.html#types . 20 * 21 * <p>Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}. 22 * 23 * <p>{@code EValue} objects are constructed with {@code EValue.from(value)}, {@code 24 * EValue.tupleFrom(value1, value2, ...)}, {@code EValue.listFrom(value1, value2, ...)}, or one of 25 * the {@code dict} methods, depending on the key type. 26 * 27 * <p>Data is retrieved from {@code EValue} objects with the {@code toX()} methods. Note that {@code 28 * str}-type EValues must be extracted with {@link #toStr()}, rather than {@link #toString()}. 29 * 30 * <p>{@code EValue} objects may retain references to objects passed into their constructors, and 31 * may return references to their internal state from {@code toX()}. 32 * 33 * <p>Warning: These APIs are experimental and subject to change without notice 34 */ 35 @Experimental 36 @DoNotStrip 37 public class EValue { 38 private static final int TYPE_CODE_NONE = 0; 39 40 private static final int TYPE_CODE_TENSOR = 1; 41 private static final int TYPE_CODE_STRING = 2; 42 private static final int TYPE_CODE_DOUBLE = 3; 43 private static final int TYPE_CODE_INT = 4; 44 private static final int TYPE_CODE_BOOL = 5; 45 46 private String[] TYPE_NAMES = { 47 "None", "Tensor", "String", "Double", "Int", "Bool", 48 }; 49 50 @DoNotStrip private final int mTypeCode; 51 @DoNotStrip private Object mData; 52 53 @DoNotStrip EValue(int typeCode)54 private EValue(int typeCode) { 55 this.mTypeCode = typeCode; 56 } 57 58 @DoNotStrip isNone()59 public boolean isNone() { 60 return TYPE_CODE_NONE == this.mTypeCode; 61 } 62 63 @DoNotStrip isTensor()64 public boolean isTensor() { 65 return TYPE_CODE_TENSOR == this.mTypeCode; 66 } 67 68 @DoNotStrip isBool()69 public boolean isBool() { 70 return TYPE_CODE_BOOL == this.mTypeCode; 71 } 72 73 @DoNotStrip isInt()74 public boolean isInt() { 75 return TYPE_CODE_INT == this.mTypeCode; 76 } 77 78 @DoNotStrip isDouble()79 public boolean isDouble() { 80 return TYPE_CODE_DOUBLE == this.mTypeCode; 81 } 82 83 @DoNotStrip isString()84 public boolean isString() { 85 return TYPE_CODE_STRING == this.mTypeCode; 86 } 87 88 /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */ 89 @DoNotStrip optionalNone()90 public static EValue optionalNone() { 91 return new EValue(TYPE_CODE_NONE); 92 } 93 94 /** Creates a new {@code EValue} of type {@code Tensor}. */ 95 @DoNotStrip from(Tensor tensor)96 public static EValue from(Tensor tensor) { 97 final EValue iv = new EValue(TYPE_CODE_TENSOR); 98 iv.mData = tensor; 99 return iv; 100 } 101 102 /** Creates a new {@code EValue} of type {@code bool}. */ 103 @DoNotStrip from(boolean value)104 public static EValue from(boolean value) { 105 final EValue iv = new EValue(TYPE_CODE_BOOL); 106 iv.mData = value; 107 return iv; 108 } 109 110 /** Creates a new {@code EValue} of type {@code int}. */ 111 @DoNotStrip from(long value)112 public static EValue from(long value) { 113 final EValue iv = new EValue(TYPE_CODE_INT); 114 iv.mData = value; 115 return iv; 116 } 117 118 /** Creates a new {@code EValue} of type {@code double}. */ 119 @DoNotStrip from(double value)120 public static EValue from(double value) { 121 final EValue iv = new EValue(TYPE_CODE_DOUBLE); 122 iv.mData = value; 123 return iv; 124 } 125 126 /** Creates a new {@code EValue} of type {@code str}. */ 127 @DoNotStrip from(String value)128 public static EValue from(String value) { 129 final EValue iv = new EValue(TYPE_CODE_STRING); 130 iv.mData = value; 131 return iv; 132 } 133 134 @DoNotStrip toTensor()135 public Tensor toTensor() { 136 preconditionType(TYPE_CODE_TENSOR, mTypeCode); 137 return (Tensor) mData; 138 } 139 140 @DoNotStrip toBool()141 public boolean toBool() { 142 preconditionType(TYPE_CODE_BOOL, mTypeCode); 143 return (boolean) mData; 144 } 145 146 @DoNotStrip toInt()147 public long toInt() { 148 preconditionType(TYPE_CODE_INT, mTypeCode); 149 return (long) mData; 150 } 151 152 @DoNotStrip toDouble()153 public double toDouble() { 154 preconditionType(TYPE_CODE_DOUBLE, mTypeCode); 155 return (double) mData; 156 } 157 158 @DoNotStrip toStr()159 public String toStr() { 160 preconditionType(TYPE_CODE_STRING, mTypeCode); 161 return (String) mData; 162 } 163 preconditionType(int typeCodeExpected, int typeCode)164 private void preconditionType(int typeCodeExpected, int typeCode) { 165 if (typeCode != typeCodeExpected) { 166 throw new IllegalStateException( 167 String.format( 168 Locale.US, 169 "Expected EValue type %s, actual type %s", 170 getTypeName(typeCodeExpected), 171 getTypeName(typeCode))); 172 } 173 } 174 getTypeName(int typeCode)175 private String getTypeName(int typeCode) { 176 return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown"; 177 } 178 179 /** 180 * Serializes an {@code EValue} into a byte array. 181 * 182 * @return The serialized byte array. 183 * @apiNote This method is experimental and subject to change without notice. 184 */ toByteArray()185 public byte[] toByteArray() { 186 if (isNone()) { 187 return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array(); 188 } else if (isTensor()) { 189 Tensor t = toTensor(); 190 byte[] tByteArray = t.toByteArray(); 191 return ByteBuffer.allocate(1 + tByteArray.length) 192 .put((byte) TYPE_CODE_TENSOR) 193 .put(tByteArray) 194 .array(); 195 } else if (isBool()) { 196 return ByteBuffer.allocate(2) 197 .put((byte) TYPE_CODE_BOOL) 198 .put((byte) (toBool() ? 1 : 0)) 199 .array(); 200 } else if (isInt()) { 201 return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array(); 202 } else if (isDouble()) { 203 return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array(); 204 } else if (isString()) { 205 return ByteBuffer.allocate(1 + toString().length()) 206 .put((byte) TYPE_CODE_STRING) 207 .put(toString().getBytes()) 208 .array(); 209 } else { 210 throw new IllegalArgumentException("Unknown Tensor dtype"); 211 } 212 } 213 214 /** 215 * Deserializes an {@code EValue} from a byte[]. 216 * 217 * @param bytes The byte array to deserialize from. 218 * @return The deserialized {@code EValue}. 219 * @apiNote This method is experimental and subject to change without notice. 220 */ fromByteArray(byte[] bytes)221 public static EValue fromByteArray(byte[] bytes) { 222 ByteBuffer buffer = ByteBuffer.wrap(bytes); 223 if (buffer == null) { 224 throw new IllegalArgumentException("buffer cannot be null"); 225 } 226 if (!buffer.hasRemaining()) { 227 throw new IllegalArgumentException("invalid buffer"); 228 } 229 int typeCode = buffer.get(); 230 switch (typeCode) { 231 case TYPE_CODE_NONE: 232 return new EValue(TYPE_CODE_NONE); 233 case TYPE_CODE_TENSOR: 234 byte[] bufferArray = buffer.array(); 235 return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length))); 236 case TYPE_CODE_STRING: 237 throw new IllegalArgumentException("TYPE_CODE_STRING is not supported"); 238 case TYPE_CODE_DOUBLE: 239 return from(buffer.getDouble()); 240 case TYPE_CODE_INT: 241 return from(buffer.getLong()); 242 case TYPE_CODE_BOOL: 243 return from(buffer.get() != 0); 244 } 245 throw new IllegalArgumentException("invalid type code: " + typeCode); 246 } 247 } 248