• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 package org.pytorch.executorch;
10 
11 import com.facebook.soloader.nativeloader.NativeLoader;
12 import com.facebook.soloader.nativeloader.SystemDelegate;
13 import org.pytorch.executorch.annotations.Experimental;
14 
15 /**
16  * Java wrapper for ExecuTorch Module.
17  *
18  * <p>Warning: These APIs are experimental and subject to change without notice
19  */
20 @Experimental
21 public class Module {
22 
23   /** Load mode for the module. Load the whole file as a buffer. */
24   public static final int LOAD_MODE_FILE = 0;
25 
26   /** Load mode for the module. Use mmap to load pages into memory. */
27   public static final int LOAD_MODE_MMAP = 1;
28 
29   /** Load mode for the module. Use memory locking and handle errors. */
30   public static final int LOAD_MODE_MMAP_USE_MLOCK = 2;
31 
32   /** Load mode for the module. Use memory locking and ignore errors. */
33   public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
34 
35   /** Reference to the NativePeer object of this module. */
36   private NativePeer mNativePeer;
37 
38   /**
39    * Loads a serialized ExecuTorch module from the specified path on the disk.
40    *
41    * @param modelPath path to file that contains the serialized ExecuTorch module.
42    * @param loadMode load mode for the module. See constants in {@link Module}.
43    * @return new {@link org.pytorch.executorch.Module} object which owns the model module.
44    */
load(final String modelPath, int loadMode)45   public static Module load(final String modelPath, int loadMode) {
46     if (!NativeLoader.isInitialized()) {
47       NativeLoader.init(new SystemDelegate());
48     }
49     return new Module(new NativePeer(modelPath, loadMode));
50   }
51 
52   /**
53    * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU.
54    *
55    * @param modelPath path to file that contains the serialized ExecuTorch module.
56    * @return new {@link org.pytorch.executorch.Module} object which owns the model module.
57    */
load(final String modelPath)58   public static Module load(final String modelPath) {
59     return load(modelPath, LOAD_MODE_FILE);
60   }
61 
Module(NativePeer nativePeer)62   Module(NativePeer nativePeer) {
63     this.mNativePeer = nativePeer;
64   }
65 
66   /**
67    * Runs the 'forward' method of this module with the specified arguments.
68    *
69    * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward'
70    *     requires inputs but no inputs are given, the function will not error out, but run 'forward'
71    *     with sample inputs.
72    * @return return value from the 'forward' method.
73    */
forward(EValue... inputs)74   public EValue[] forward(EValue... inputs) {
75     return mNativePeer.forward(inputs);
76   }
77 
78   /**
79    * Runs the specified method of this module with the specified arguments.
80    *
81    * @param methodName name of the ExecuTorch method to run.
82    * @param inputs arguments that will be passed to ExecuTorch method.
83    * @return return value from the method.
84    */
execute(String methodName, EValue... inputs)85   public EValue[] execute(String methodName, EValue... inputs) {
86     return mNativePeer.execute(methodName, inputs);
87   }
88 
89   /**
90    * Load a method on this module. This might help with the first time inference performance,
91    * because otherwise the method is loaded lazily when it's execute. Note: this function is
92    * synchronous, and will block until the method is loaded. Therefore, it is recommended to call
93    * this on a background thread. However, users need to make sure that they don't execute before
94    * this function returns.
95    *
96    * @return the Error code if there was an error loading the method
97    */
loadMethod(String methodName)98   public int loadMethod(String methodName) {
99     return mNativePeer.loadMethod(methodName);
100   }
101 
102   /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
readLogBuffer()103   public String[] readLogBuffer() {
104     return mNativePeer.readLogBuffer();
105   }
106 
107   /**
108    * Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the
109    * native object will be destroyed when this object is garbage-collected. However, the timing of
110    * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
111    * more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
112    */
destroy()113   public void destroy() {
114     mNativePeer.resetNative();
115   }
116 }
117