/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import org.pytorch.executorch.annotations.Experimental;

/**
 * LlamaModule is a wrapper around the ExecuTorch Llama model. It provides a simple interface to
 * generate text from the model.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice.
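 *
 * <p>A minimal usage sketch (illustrative only; the model and tokenizer paths are hypothetical,
 * and {@code callback} is a {@link LlamaCallback} implementation such as the one sketched at
 * {@link #generate(String, LlamaCallback)}):
 *
 * <pre>{@code
 * LlamaModule module =
 *     new LlamaModule("/data/local/tmp/model.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
 * module.load(); // optional; otherwise the model is loaded during the first generate()
 * module.generate("Hello,", callback); // streams generated text into the callback
 * module.resetNative(); // release native resources when done
 * }</pre>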
 */
@Experimental
public class LlamaModule {

  public static final int MODEL_TYPE_TEXT = 1;
  public static final int MODEL_TYPE_TEXT_VISION = 2;

  static {
    if (!NativeLoader.isInitialized()) {
      NativeLoader.init(new SystemDelegate());
    }
    NativeLoader.loadLibrary("executorch");
  }

  private final HybridData mHybridData;
  private static final int DEFAULT_SEQ_LEN = 128;
  private static final boolean DEFAULT_ECHO = true;

  @DoNotStrip
  private static native HybridData initHybrid(
      int modelType, String modulePath, String tokenizerPath, float temperature);

  /** Constructs a LLaMA module for a model with the given path, tokenizer, and temperature. */
  public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature);
  }

  /** Constructs an LLM module for a model with the given type, path, tokenizer, and temperature. */
  public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature);
  }

  /** Releases the native resources held by this module. */
  public void resetNative() {
    mHybridData.resetNative();
  }

  /**
   * Start generating tokens from the module.
   *
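   * <p>A sketch of a streaming callback (assuming {@link LlamaCallback} in this package exposes
   * {@code onResult(String)} and {@code onStats(float)}):
   *
   * <pre>{@code
   * int status = module.generate("Tell me a story", new LlamaCallback() {
   *   public void onResult(String result) {
   *     System.out.print(result); // one decoded token piece per call
   *   }
   *
   *   public void onStats(float tps) {
   *     System.out.println("\ntokens/s: " + tps);
   *   }
   * });
   * }</pre>
   *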
   * @param prompt Input prompt
   * @param llamaCallback callback object to receive results.
   */
  public int generate(String prompt, LlamaCallback llamaCallback) {
    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
  }

  /**
   * Start generating tokens from the module.
   *
   * @param prompt Input prompt
   * @param seqLen sequence length
   * @param llamaCallback callback object to receive results.
   */
  public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, DEFAULT_ECHO);
  }

  /**
   * Start generating tokens from the module.
   *
   * @param prompt Input prompt
   * @param llamaCallback callback object to receive results
   * @param echo whether to echo the input prompt in the output (text completion vs. chat)
   */
  public int generate(String prompt, LlamaCallback llamaCallback, boolean echo) {
    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, echo);
  }

  /**
   * Start generating tokens from the module.
   *
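   * <p>For example (values are illustrative), a chat-style call that returns only newly
   * generated tokens and caps the total sequence length at 256:
   *
   * <pre>{@code
   * // echo=false: only new tokens are streamed back (chat);
   * // echo=true also streams back the prompt (text completion).
   * int status = module.generate(userMessage, 256, callback, false);
   * }</pre>
   *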
   * @param prompt Input prompt
   * @param seqLen sequence length
   * @param llamaCallback callback object to receive results
   * @param echo whether to echo the input prompt in the output (text completion vs. chat)
   */
  public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, boolean echo) {
    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, echo);
  }

  /**
   * Start generating tokens from the module.
   *
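   * <p>The image is passed as a flattened array, presumably of {@code width * height * channels}
   * values. An illustrative call (sizes and prompt are hypothetical):
   *
   * <pre>{@code
   * int width = 336, height = 336, channels = 3;
   * int[] image = new int[width * height * channels]; // filled with pixel data
   * int status = module.generate(image, width, height, channels, "What is in this image?",
   *     768, callback, false);
   * }</pre>
   *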
   * @param image Input image as a flattened int array
   * @param width Input image width
   * @param height Input image height
   * @param channels Input image number of channels
   * @param prompt Input prompt
   * @param seqLen sequence length
   * @param llamaCallback callback object to receive results.
   * @param echo whether to echo the input prompt in the output (text completion vs. chat)
   */
  @DoNotStrip
  public native int generate(
      int[] image,
      int width,
      int height,
      int channels,
      String prompt,
      int seqLen,
      LlamaCallback llamaCallback,
      boolean echo);

  /**
   * Prefill an LLaVA module with the given image input.
   *
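   * <p>Illustrative chaining sketch: the returned position feeds the next prefill or generate
   * call.
   *
   * <pre>{@code
   * long pos = module.prefillImages(image, width, height, channels, 0);
   * }</pre>
   *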
   * @param image Input image as a flattened int array
   * @param width Input image width
   * @param height Input image height
   * @param channels Input image number of channels
   * @param startPos The starting position in the KV cache of the input in the LLM.
   * @return The updated starting position in the KV cache of the input in the LLM.
   * @throws RuntimeException if the prefill fails
   */
  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
    if (nativeResult[0] != 0) {
      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
    }
    return nativeResult[1];
  }

  // returns a tuple of (status, updated startPos)
  private native long[] prefillImagesNative(
      int[] image, int width, int height, int channels, long startPos);

  /**
   * Prefill an LLaVA module with the given text input.
   *
   * @param prompt The text prompt to LLaVA.
   * @param startPos The starting position in the KV cache of the input in the LLM. The updated
   *     position is returned by this function rather than modified in place.
   * @param bos The number of BOS (beginning of sequence) tokens.
   * @param eos The number of EOS (end of sequence) tokens.
   * @return The updated starting position in the KV cache of the input in the LLM.
   * @throws RuntimeException if the prefill fails
   */
  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
    if (nativeResult[0] != 0) {
      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
    }
    return nativeResult[1];
  }

  // returns a tuple of (status, updated startPos)
  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);

  /**
   * Generate tokens from the given prompt, starting from the given position.
   *
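   * <p>A typical LLaVA-style flow (illustrative; prompt strings and token counts are
   * hypothetical) prefills the image and prompt, then generates from the resulting position:
   *
   * <pre>{@code
   * long pos = module.prefillImages(image, width, height, channels, 0);
   * pos = module.prefillPrompt("USER: What is in this image?", pos, 1, 0);
   * int err = module.generateFromPos("ASSISTANT:", 768, pos, callback, false);
   * }</pre>
   *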
   * @param prompt The text prompt to LLaVA.
   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
   * @param startPos The starting position in the KV cache of the input in the LLM.
   * @param callback callback object to receive results.
   * @param echo whether to echo the input prompt in the output (text completion vs. chat)
   * @return The error code.
   */
  public native int generateFromPos(
      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);

  /** Stop the current generate() before it finishes. */
  @DoNotStrip
  public native void stop();

  /** Force loading the module. Otherwise the model is loaded during the first generate(). */
  @DoNotStrip
  public native int load();
}