/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// NOTE: the ExecuTorch/fbjni include paths below are reconstructed from the
// symbols used in this file; verify them against the checkout's layout.
#include <executorch/examples/models/llama/runner/runner.h>
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/runtime.h>

#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif

#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>

#if defined(EXECUTORCH_BUILD_MEDIATEK)
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
#endif

namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;

namespace {

// Returns true if str[0, length) ends on a complete, well-formed UTF-8
// sequence. This is a lightweight lead-byte/continuation-byte check; it does
// not reject overlong or out-of-range encodings.
bool utf8_check_validity(const char* str, size_t length) {
  for (size_t i = 0; i < length; ++i) {
    uint8_t byte = static_cast<uint8_t>(str[i]);
    if (byte >= 0x80) { // Non-ASCII byte
      if (i + 1 >= length) { // Incomplete sequence
        return false;
      }
      uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
      if ((byte & 0xE0) == 0xC0 && (next_byte & 0xC0) == 0x80) {
        // 2-byte sequence
        i += 1;
      } else if (
          (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 &&
          (i + 2 < length) &&
          (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80) {
        // 3-byte sequence
        i += 2;
      } else if (
          (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 &&
          (i + 2 < length) &&
          (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80 &&
          (i + 3 < length) &&
          (static_cast<uint8_t>(str[i + 3]) & 0xC0) == 0x80) {
        // 4-byte sequence
        i += 3;
      } else {
        return false; // Invalid sequence
      }
    }
  }
  return true; // All bytes were valid
}

std::string token_buffer;

} // namespace

namespace executorch_jni {

class ExecuTorchLlamaCallbackJni
    : public facebook::jni::JavaClass<ExecuTorchLlamaCallbackJni> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/LlamaCallback;";

  void onResult(std::string result) const {
    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
    static const auto method =
        cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
    // Buffer partial tokens until the accumulated bytes form valid UTF-8, so
    // the Java callback never sees a split multi-byte character.
    token_buffer += result;
    if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
      ET_LOG(
          Info, "Current token buffer is not valid UTF-8. Waiting for more.");
      return;
    }
    result = token_buffer;
    token_buffer = "";
    facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
    method(self(), s);
  }

  void onStats(const llm::Stats& result) const {
    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
    static const auto method = cls->getMethod<void(jfloat)>("onStats");
    double eval_time =
        (double)(result.inference_end_ms - result.prompt_eval_end_ms);

    float tps = result.num_generated_tokens / eval_time *
        result.SCALING_FACTOR_UNITS_PER_SECOND;

    method(self(), tps);
  }
};

class ExecuTorchLlamaJni
    : public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
 private:
  friend HybridBase;
  int model_type_category_;
  // Concrete runners (text-only and multimodal) are owned through these
  // interface pointers, selected by model_type_category_.
  std::unique_ptr<llm::IRunner> runner_;
  std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;

 public:
  constexpr static auto kJavaDescriptor =
      "Lorg/pytorch/executorch/LlamaModule;";

  constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
  constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
  constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;

  static facebook::jni::local_ref<jhybriddata> initHybrid(
      facebook::jni::alias_ref<jclass>,
      jint model_type_category,
      facebook::jni::alias_ref<jstring> model_path,
      facebook::jni::alias_ref<jstring> tokenizer_path,
      jfloat temperature) {
    return makeCxxInstance(
        model_type_category, model_path, tokenizer_path, temperature);
  }

  ExecuTorchLlamaJni(
      jint model_type_category,
      facebook::jni::alias_ref<jstring> model_path,
      facebook::jni::alias_ref<jstring> tokenizer_path,
      jfloat temperature) {
#if defined(ET_USE_THREADPOOL)
    // Reserve 1 thread for the main thread.
    uint32_t num_performant_cores =
        ::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
    if (num_performant_cores > 0) {
      ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
      ::executorch::extension::threadpool::get_threadpool()
          ->_unsafe_reset_threadpool(num_performant_cores);
    }
#endif

    model_type_category_ = model_type_category;
    if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
          model_path->toStdString().c_str(),
          tokenizer_path->toStdString().c_str(),
          temperature);
    } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
      runner_ = std::make_unique<example::Runner>(
          model_path->toStdString().c_str(),
          tokenizer_path->toStdString().c_str(),
          temperature);
#if defined(EXECUTORCH_BUILD_MEDIATEK)
    } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
      runner_ = std::make_unique<MTKLlamaRunner>(
          model_path->toStdString().c_str(),
          tokenizer_path->toStdString().c_str(),
          temperature);
      // Interpret the model type as LLM
      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
#endif
    }
  }

  jint generate(
      facebook::jni::alias_ref<jintArray> image,
      jint width,
      jint height,
      jint channels,
      facebook::jni::alias_ref<jstring> prompt,
      jint seq_len,
      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
      jboolean echo) {
    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      auto image_size = image->size();
      std::vector<llm::Image> images;
      if (image_size != 0) {
        // Copy the jint image buffer into the uint8_t layout the runner
        // expects.
        std::vector<jint> image_data_jint(image_size);
        std::vector<uint8_t> image_data(image_size);
        image->getRegion(0, image_size, image_data_jint.data());
        for (int i = 0; i < image_size; i++) {
          image_data[i] = image_data_jint[i];
        }
        llm::Image image_runner{image_data, width, height, channels};
        images.push_back(image_runner);
      }
      multi_modal_runner_->generate(
          std::move(images),
          prompt->toStdString(),
          seq_len,
          [callback](std::string result) { callback->onResult(result); },
          [callback](const llm::Stats& result) { callback->onStats(result); },
          echo);
    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
      runner_->generate(
          prompt->toStdString(),
          seq_len,
          [callback](std::string result) { callback->onResult(result); },
          [callback](const llm::Stats& result) { callback->onStats(result); },
          echo);
    }
    return 0;
  }

  // Returns a tuple of (error, start_pos)
  // Contract is valid within an AAR (JNI + corresponding Java code)
  // If the first element is not Error::Ok, the other element is undefined.
  facebook::jni::local_ref<jlongArray> prefill_prompt(
      facebook::jni::alias_ref<jstring> prompt,
      jlong start_pos,
      jint bos,
      jint eos) {
    facebook::jni::local_ref<jlongArray> tuple_result =
        facebook::jni::make_long_array(2);
    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
      return tuple_result;
    }

    auto&& result = multi_modal_runner_->prefill_prompt(
        prompt->toStdString(), start_pos, bos, eos);
    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
    if (result.ok()) {
      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
    }
    return tuple_result;
  }

  // Returns a tuple of (error, start_pos)
  // Contract is valid within an AAR (JNI + corresponding Java code)
  // If the first element is not Error::Ok, the other element is undefined.
  facebook::jni::local_ref<jlongArray> prefill_images(
      facebook::jni::alias_ref<jintArray> image,
      jint width,
      jint height,
      jint channels,
      jlong start_pos) {
    facebook::jni::local_ref<jlongArray> tuple_result =
        facebook::jni::make_long_array(2);

    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
      return tuple_result;
    }

    auto image_size = image->size();
    std::vector<llm::Image> images;
    if (image_size != 0) {
      std::vector<jint> image_data_jint(image_size);
      std::vector<uint8_t> image_data(image_size);
      image->getRegion(0, image_size, image_data_jint.data());
      for (int i = 0; i < image_size; i++) {
        image_data[i] = image_data_jint[i];
      }
      llm::Image image_runner{image_data, width, height, channels};
      images.push_back(image_runner);
    }
    // TODO(hsz): make start_pos a reference and update it here
    jint result = static_cast<jint>(
        multi_modal_runner_->prefill_images(images, start_pos));
    tuple_result->pin()[0] = result;
    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
    return tuple_result;
  }

  jint generate_from_pos(
      facebook::jni::alias_ref<jstring> prompt,
      jint seq_len,
      jlong start_pos,
      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
      jboolean echo) {
    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      return static_cast<jint>(Error::NotSupported);
    }
    return static_cast<jint>(multi_modal_runner_->generate_from_pos(
        prompt->toStdString(),
        seq_len,
        start_pos,
        [callback](const std::string& result) { callback->onResult(result); },
        [callback](const llm::Stats& stats) { callback->onStats(stats); },
        echo));
  }

  void stop() {
    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      multi_modal_runner_->stop();
    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
      runner_->stop();
    }
  }

  jint load() {
    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      return static_cast<jint>(multi_modal_runner_->load());
    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
      return static_cast<jint>(runner_->load());
    }
    return static_cast<jint>(Error::InvalidArgument);
  }

  static void registerNatives() {
    registerHybrid({
        makeNativeMethod("initHybrid", ExecuTorchLlamaJni::initHybrid),
        makeNativeMethod("generate", ExecuTorchLlamaJni::generate),
        makeNativeMethod("stop", ExecuTorchLlamaJni::stop),
        makeNativeMethod("load", ExecuTorchLlamaJni::load),
        makeNativeMethod(
            "prefillImagesNative", ExecuTorchLlamaJni::prefill_images),
        makeNativeMethod(
            "prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt),
        makeNativeMethod(
            "generateFromPos", ExecuTorchLlamaJni::generate_from_pos),
    });
  }
};

} // namespace executorch_jni

void register_natives_for_llama() {
  executorch_jni::ExecuTorchLlamaJni::registerNatives();
}
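
// Illustrative Java-side call sequence, kept as a comment so this translation
// unit stays valid C++. The class names (org.pytorch.executorch.LlamaModule,
// LlamaCallback) and native method names come from the descriptors and
// registerNatives() above; the exact Java wrapper signatures shown here are
// assumptions inferred from initHybrid() and generate(), not a definitive API.
//
//   // modelType 1 = LLM, 2 = multimodal (see MODEL_TYPE_CATEGORY_* above).
//   LlamaModule module = new LlamaModule(
//       /* modelType */ 1, modelPath, tokenizerPath, /* temperature */ 0.8f);
//   module.load();                            // -> ExecuTorchLlamaJni::load
//   module.generate(prompt, seqLen, callback); // -> ExecuTorchLlamaJni::generate
//   module.stop();                            // -> ExecuTorchLlamaJni::stop
//
//   // prefillPromptNative/prefillImagesNative return a long[2] encoding
//   // (error, start_pos); Error::Ok is 0 in the ExecuTorch runtime.
//   long[] r = prefillPromptNative(prompt, startPos, bos, eos);
//   if (r[0] == 0) { startPos = r[1]; }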