/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import org.pytorch.executorch.annotations.Experimental;

/**
 * LlamaModule is a wrapper around the ExecuTorch Llama model. It provides a simple interface to
 * generate text from the model.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice.
 */
@Experimental
public class LlamaModule {

  public static final int MODEL_TYPE_TEXT = 1;
  public static final int MODEL_TYPE_TEXT_VISION = 2;

  static {
    if (!NativeLoader.isInitialized()) {
      NativeLoader.init(new SystemDelegate());
    }
    NativeLoader.loadLibrary("executorch");
  }

  private final HybridData mHybridData;
  private static final int DEFAULT_SEQ_LEN = 128;
  private static final boolean DEFAULT_ECHO = true;

  @DoNotStrip
  private static native HybridData initHybrid(
      int modelType, String modulePath, String tokenizerPath, float temperature);

  /** Constructs a Llama Module for a text model with the given path, tokenizer, and temperature. */
  public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature);
  }

  /** Constructs an LLM Module for a model with the given type, path, tokenizer, and temperature. */
  public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature);
  }

  /** Releases the native resources held by this module. */
  public void resetNative() {
    mHybridData.resetNative();
  }

  /**
   * Start generating tokens from the module.
   *
   * @param prompt Input prompt
   * @param llamaCallback callback object to receive results
   */
  public int generate(String prompt, LlamaCallback llamaCallback) {
    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
  }

  /**
   * Start generating tokens from the module.
   *
   * @param prompt Input prompt
   * @param seqLen sequence length
   * @param llamaCallback callback object to receive results
   */
  public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, DEFAULT_ECHO);
  }
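  // Illustrative usage sketch (not part of this file's API surface): a minimal text-generation
  // flow. The file paths, temperature, and the LlamaCallback methods shown (onResult/onStats)
  // are assumptions for the example and may differ in your application.
  //
  //   LlamaModule module =
  //       new LlamaModule(
  //           "/data/local/tmp/llama/model.pte", "/data/local/tmp/llama/tokenizer.bin", 0.8f);
  //   module.load(); // optional: load eagerly instead of on the first generate()
  //   module.generate(
  //       "Hello, world",
  //       new LlamaCallback() {
  //         @Override
  //         public void onResult(String token) {
  //           // append the generated token to the UI
  //         }
  //
  //         @Override
  //         public void onStats(float tokensPerSecond) {
  //           // report generation throughput
  //         }
  //       });
  //   // stop() may be called from another thread to end generation early;
  //   // resetNative() releases native resources once the module is no longer needed.
  //   module.resetNative();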
  /**
   * Start generating tokens from the module.
   *
   * @param prompt Input prompt
   * @param llamaCallback callback object to receive results
   * @param echo whether to echo the input prompt in the output (text completion vs. chat)
   */
  public int generate(String prompt, LlamaCallback llamaCallback, boolean echo) {
    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, echo);
  }

  /**
   * Start generating tokens from the module.
   *
   * @param prompt Input prompt
   * @param seqLen sequence length
   * @param llamaCallback callback object to receive results
   * @param echo whether to echo the input prompt in the output (text completion vs. chat)
   */
  public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, boolean echo) {
    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, echo);
  }

  /**
   * Start generating tokens from the module.
   *
   * @param image Input image as an int array
   * @param width Input image width
   * @param height Input image height
   * @param channels Input image number of channels
   * @param prompt Input prompt
   * @param seqLen sequence length
   * @param llamaCallback callback object to receive results
   * @param echo whether to echo the input prompt in the output (text completion vs. chat)
   */
  @DoNotStrip
  public native int generate(
      int[] image,
      int width,
      int height,
      int channels,
      String prompt,
      int seqLen,
      LlamaCallback llamaCallback,
      boolean echo);

  /**
   * Prefill an LLaVA Module with the given image input.
   *
   * @param image Input image as an int array
   * @param width Input image width
   * @param height Input image height
   * @param channels Input image number of channels
   * @param startPos The starting position in the KV cache of the input in the LLM.
   * @return The updated starting position in the KV cache of the input in the LLM.
   * @throws RuntimeException if the prefill failed
   */
  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
    if (nativeResult[0] != 0) {
      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
    }
    return nativeResult[1];
  }

  // returns a tuple of (status, updated startPos)
  private native long[] prefillImagesNative(
      int[] image, int width, int height, int channels, long startPos);

  /**
   * Prefill an LLaVA Module with the given text input.
   *
   * @param prompt The text prompt to LLaVA.
   * @param startPos The starting position in the KV cache of the input in the LLM. The updated
   *     position is returned as the result of this method.
   * @param bos The number of BOS (begin of sequence) tokens.
   * @param eos The number of EOS (end of sequence) tokens.
   * @return The updated starting position in the KV cache of the input in the LLM.
   * @throws RuntimeException if the prefill failed
   */
  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
    if (nativeResult[0] != 0) {
      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
    }
    return nativeResult[1];
  }

  // returns a tuple of (status, updated startPos)
  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);

  /**
   * Generate tokens from the given prompt, starting from the given position.
   *
   * @param prompt The text prompt to LLaVA.
   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
   * @param startPos The starting position in the KV cache of the input in the LLM.
   * @param callback callback object to receive results
   * @param echo whether to echo the input prompt in the output
   * @return The error code.
   */
  public native int generateFromPos(
      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);
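  // Illustrative usage sketch (not part of this file's API surface): one possible multimodal
  // (LLaVA-style) flow built from the prefill APIs above. The prompt text, image size, bos/eos
  // counts, and seqLen are placeholder values, and the int[] imagePixels is assumed to hold the
  // pixel data in the layout the exported model expects.
  //
  //   long startPos = 0;
  //   startPos =
  //       module.prefillPrompt("A chat between a user and an assistant. USER: ", startPos, 1, 0);
  //   startPos = module.prefillImages(imagePixels, 336, 336, 3, startPos);
  //   int err =
  //       module.generateFromPos(
  //           "What is in this image? ASSISTANT: ", 768, startPos, llamaCallback, false);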
  /** Stop the current generate() call before it finishes. */
  @DoNotStrip
  public native void stop();

  /** Force loading the module. Otherwise, the model is loaded during the first generate() call. */
  @DoNotStrip
  public native int load();
}