1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 package org.pytorch.executorch; 10 11 import com.facebook.soloader.nativeloader.NativeLoader; 12 import com.facebook.soloader.nativeloader.SystemDelegate; 13 import org.pytorch.executorch.annotations.Experimental; 14 15 /** 16 * Java wrapper for ExecuTorch Module. 17 * 18 * <p>Warning: These APIs are experimental and subject to change without notice 19 */ 20 @Experimental 21 public class Module { 22 23 /** Load mode for the module. Load the whole file as a buffer. */ 24 public static final int LOAD_MODE_FILE = 0; 25 26 /** Load mode for the module. Use mmap to load pages into memory. */ 27 public static final int LOAD_MODE_MMAP = 1; 28 29 /** Load mode for the module. Use memory locking and handle errors. */ 30 public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; 31 32 /** Load mode for the module. Use memory locking and ignore errors. */ 33 public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; 34 35 /** Reference to the NativePeer object of this module. */ 36 private NativePeer mNativePeer; 37 38 /** 39 * Loads a serialized ExecuTorch module from the specified path on the disk. 40 * 41 * @param modelPath path to file that contains the serialized ExecuTorch module. 42 * @param loadMode load mode for the module. See constants in {@link Module}. 43 * @return new {@link org.pytorch.executorch.Module} object which owns the model module. 44 */ load(final String modelPath, int loadMode)45 public static Module load(final String modelPath, int loadMode) { 46 if (!NativeLoader.isInitialized()) { 47 NativeLoader.init(new SystemDelegate()); 48 } 49 return new Module(new NativePeer(modelPath, loadMode)); 50 } 51 52 /** 53 * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU. 54 * 55 * @param modelPath path to file that contains the serialized ExecuTorch module. 56 * @return new {@link org.pytorch.executorch.Module} object which owns the model module. 57 */ load(final String modelPath)58 public static Module load(final String modelPath) { 59 return load(modelPath, LOAD_MODE_FILE); 60 } 61 Module(NativePeer nativePeer)62 Module(NativePeer nativePeer) { 63 this.mNativePeer = nativePeer; 64 } 65 66 /** 67 * Runs the 'forward' method of this module with the specified arguments. 68 * 69 * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward' 70 * requires inputs but no inputs are given, the function will not error out, but run 'forward' 71 * with sample inputs. 72 * @return return value from the 'forward' method. 73 */ forward(EValue... inputs)74 public EValue[] forward(EValue... inputs) { 75 return mNativePeer.forward(inputs); 76 } 77 78 /** 79 * Runs the specified method of this module with the specified arguments. 80 * 81 * @param methodName name of the ExecuTorch method to run. 82 * @param inputs arguments that will be passed to ExecuTorch method. 83 * @return return value from the method. 84 */ execute(String methodName, EValue... inputs)85 public EValue[] execute(String methodName, EValue... inputs) { 86 return mNativePeer.execute(methodName, inputs); 87 } 88 89 /** 90 * Load a method on this module. This might help with the first time inference performance, 91 * because otherwise the method is loaded lazily when it's execute. Note: this function is 92 * synchronous, and will block until the method is loaded. Therefore, it is recommended to call 93 * this on a background thread. However, users need to make sure that they don't execute before 94 * this function returns. 95 * 96 * @return the Error code if there was an error loading the method 97 */ loadMethod(String methodName)98 public int loadMethod(String methodName) { 99 return mNativePeer.loadMethod(methodName); 100 } 101 102 /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ readLogBuffer()103 public String[] readLogBuffer() { 104 return mNativePeer.readLogBuffer(); 105 } 106 107 /** 108 * Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the 109 * native object will be destroyed when this object is garbage-collected. However, the timing of 110 * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory 111 * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. 112 */ destroy()113 public void destroy() { 114 mNativePeer.resetNative(); 115 } 116 } 117