/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include "jni_layer_constants.h" #include #include #include #include #include #include #include #include #ifdef ET_USE_THREADPOOL #include #include #endif #include #include using namespace executorch::extension; using namespace torch::executor; namespace executorch::extension { class TensorHybrid : public facebook::jni::HybridClass { public: constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/Tensor;"; explicit TensorHybrid(exec_aten::Tensor tensor) {} static facebook::jni::local_ref newJTensorFromTensor(const exec_aten::Tensor& tensor) { // Java wrapper currently only supports contiguous tensors. const auto scalarType = tensor.scalar_type(); if (scalar_type_to_java_dtype.count(scalarType) == 0) { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "exec_aten::Tensor scalar type %d is not supported on java side", scalarType); } int jdtype = scalar_type_to_java_dtype.at(scalarType); const auto& tensor_shape = tensor.sizes(); std::vector tensor_shape_vec; for (const auto& s : tensor_shape) { tensor_shape_vec.push_back(s); } facebook::jni::local_ref jTensorShape = facebook::jni::make_long_array(tensor_shape_vec.size()); jTensorShape->setRegion( 0, tensor_shape_vec.size(), tensor_shape_vec.data()); static auto cls = TensorHybrid::javaClassStatic(); // Note: this is safe as long as the data stored in tensor is valid; the // data won't go out of scope as long as the Method for the inference is // valid and there is no other inference call. Java layer picks up this // value immediately so the data is valid. facebook::jni::local_ref jTensorBuffer = facebook::jni::JByteBuffer::wrapBytes( (uint8_t*)tensor.data_ptr(), tensor.nbytes()); jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder()); static const auto jMethodNewTensor = cls->getStaticMethod( facebook::jni::alias_ref, facebook::jni::alias_ref, jint, facebook::jni::alias_ref)>("nativeNewTensor"); return jMethodNewTensor( cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor)); } private: friend HybridBase; }; class JEValue : public facebook::jni::JavaClass { public: constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/EValue;"; constexpr static int kTypeCodeTensor = 1; constexpr static int kTypeCodeString = 2; constexpr static int kTypeCodeDouble = 3; constexpr static int kTypeCodeInt = 4; constexpr static int kTypeCodeBool = 5; static facebook::jni::local_ref newJEValueFromEValue(EValue evalue) { if (evalue.isTensor()) { static auto jMethodTensor = JEValue::javaClassStatic() ->getStaticMethod( facebook::jni::local_ref)>("from"); return jMethodTensor( JEValue::javaClassStatic(), TensorHybrid::newJTensorFromTensor(evalue.toTensor())); } else if (evalue.isInt()) { static auto jMethodTensor = JEValue::javaClassStatic() ->getStaticMethod(jlong)>( "from"); return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt()); } else if (evalue.isDouble()) { static auto jMethodTensor = JEValue::javaClassStatic() ->getStaticMethod(jdouble)>( "from"); return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble()); } else if (evalue.isBool()) { static auto jMethodTensor = JEValue::javaClassStatic() ->getStaticMethod(jboolean)>( "from"); return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool()); } else if (evalue.isString()) { static auto jMethodTensor = JEValue::javaClassStatic() ->getStaticMethod( facebook::jni::local_ref)>("from"); std::string str = std::string(evalue.toString().begin(), evalue.toString().end()); return jMethodTensor( JEValue::javaClassStatic(), facebook::jni::make_jstring(str)); } facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "Unsupported EValue type: %d", evalue.tag); } static TensorPtr JEValueToTensorImpl( facebook::jni::alias_ref JEValue) { static const auto typeCodeField = JEValue::javaClassStatic()->getField("mTypeCode"); const auto typeCode = JEValue->getFieldValue(typeCodeField); if (JEValue::kTypeCodeTensor == typeCode) { static const auto jMethodGetTensor = JEValue::javaClassStatic() ->getMethod()>( "toTensor"); auto jtensor = jMethodGetTensor(JEValue); static auto cls = TensorHybrid::javaClassStatic(); static const auto dtypeMethod = cls->getMethod("dtypeJniCode"); jint jdtype = dtypeMethod(jtensor); static const auto shapeField = cls->getField("shape"); auto jshape = jtensor->getFieldValue(shapeField); static auto dataBufferMethod = cls->getMethod< facebook::jni::local_ref()>( "getRawDataBuffer"); facebook::jni::local_ref jbuffer = dataBufferMethod(jtensor); const auto rank = jshape->size(); const auto shapeArr = jshape->getRegion(0, rank); std::vector shape_vec; shape_vec.reserve(rank); auto numel = 1; for (int i = 0; i < rank; i++) { shape_vec.push_back(shapeArr[i]); } for (int i = rank - 1; i >= 0; --i) { numel *= shapeArr[i]; } JNIEnv* jni = facebook::jni::Environment::current(); if (java_dtype_to_scalar_type.count(jdtype) == 0) { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "Unknown Tensor jdtype %d", jdtype); } ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); if (dataCapacity != numel) { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "Tensor dimensions(elements number:%d inconsistent with buffer capacity(%d)", numel, dataCapacity); } return from_blob( jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type); } facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "Unknown EValue typeCode %d", typeCode); } }; class ExecuTorchJni : public facebook::jni::HybridClass { private: friend HybridBase; std::unique_ptr module_; public: constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/NativePeer;"; static facebook::jni::local_ref initHybrid( facebook::jni::alias_ref, facebook::jni::alias_ref modelPath, jint loadMode) { return makeCxxInstance(modelPath, loadMode); } ExecuTorchJni(facebook::jni::alias_ref modelPath, jint loadMode) { Module::LoadMode load_mode = Module::LoadMode::Mmap; if (loadMode == 0) { load_mode = Module::LoadMode::File; } else if (loadMode == 1) { load_mode = Module::LoadMode::Mmap; } else if (loadMode == 2) { load_mode = Module::LoadMode::MmapUseMlock; } else if (loadMode == 3) { load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors; } module_ = std::make_unique(modelPath->toStdString(), load_mode); #ifdef ET_USE_THREADPOOL // Default to using cores/2 threadpool threads. The long-term plan is to // improve performant core detection in CPUInfo, but for now we can use // cores/2 as a sane default. // // Based on testing, this is almost universally faster than using all // cores, as efficiency cores can be quite slow. In extreme cases, using // all cores can be 10x slower than using cores/2. // // TODO Allow overriding this default from Java. auto threadpool = executorch::extension::threadpool::get_threadpool(); if (threadpool) { int thread_count = cpuinfo_get_processors_count() / 2; if (thread_count > 0) { threadpool->_unsafe_reset_threadpool(thread_count); } } #endif } facebook::jni::local_ref> forward( facebook::jni::alias_ref< facebook::jni::JArrayClass::javaobject> jinputs) { return execute_method("forward", jinputs); } facebook::jni::local_ref> execute( facebook::jni::alias_ref methodName, facebook::jni::alias_ref< facebook::jni::JArrayClass::javaobject> jinputs) { return execute_method(methodName->toStdString(), jinputs); } jint load_method(facebook::jni::alias_ref methodName) { return static_cast(module_->load_method(methodName->toStdString())); } facebook::jni::local_ref> execute_method( std::string method, facebook::jni::alias_ref< facebook::jni::JArrayClass::javaobject> jinputs) { // If no inputs is given, it will run with sample inputs (ones) if (jinputs->size() == 0) { if (module_->load_method(method) != Error::Ok) { return {}; } auto&& underlying_method = module_->methods_[method].method; auto&& buf = prepare_input_tensors(*underlying_method); auto result = underlying_method->execute(); if (result != Error::Ok) { return {}; } facebook::jni::local_ref> jresult = facebook::jni::JArrayClass::newArray( underlying_method->outputs_size()); for (int i = 0; i < underlying_method->outputs_size(); i++) { auto jevalue = JEValue::newJEValueFromEValue(underlying_method->get_output(i)); jresult->setElement(i, *jevalue); } return jresult; } std::vector evalues; std::vector tensors; static const auto typeCodeField = JEValue::javaClassStatic()->getField("mTypeCode"); for (int i = 0; i < jinputs->size(); i++) { auto jevalue = jinputs->getElement(i); const auto typeCode = jevalue->getFieldValue(typeCodeField); if (typeCode == JEValue::kTypeCodeTensor) { tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue)); evalues.emplace_back(tensors.back()); } else if (typeCode == JEValue::kTypeCodeInt) { int64_t value = jevalue->getFieldValue(typeCodeField); evalues.emplace_back(value); } else if (typeCode == JEValue::kTypeCodeDouble) { double value = jevalue->getFieldValue(typeCodeField); evalues.emplace_back(value); } else if (typeCode == JEValue::kTypeCodeBool) { bool value = jevalue->getFieldValue(typeCodeField); evalues.emplace_back(value); } } #ifdef EXECUTORCH_ANDROID_PROFILING auto start = std::chrono::high_resolution_clock::now(); auto result = module_->execute(method, evalues); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start) .count(); ET_LOG(Debug, "Execution time: %lld ms.", duration); #else auto result = module_->execute(method, evalues); #endif if (!result.ok()) { facebook::jni::throwNewJavaException( "java/lang/Exception", "Execution of method %s failed with status 0x%" PRIx32, method.c_str(), static_cast(result.error())); return {}; } facebook::jni::local_ref> jresult = facebook::jni::JArrayClass::newArray(result.get().size()); for (int i = 0; i < result.get().size(); i++) { auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]); jresult->setElement(i, *jevalue); } return jresult; } facebook::jni::local_ref> readLogBuffer() { #ifdef __ANDROID__ facebook::jni::local_ref> ret; access_log_buffer([&](std::vector& buffer) { const auto size = buffer.size(); ret = facebook::jni::JArrayClass::newArray(size); for (auto i = 0u; i < size; i++) { const auto& entry = buffer[i]; // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL // MESSAGE". std::stringstream ss; ss << "[" << entry.timestamp << " " << entry.function << " " << entry.filename << ":" << entry.line << "] " << static_cast(entry.level) << " " << entry.message; facebook::jni::local_ref jstr_message = facebook::jni::make_jstring(ss.str().c_str()); (*ret)[i] = jstr_message; } }); return ret; #else return facebook::jni::JArrayClass::newArray(0); #endif } static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), makeNativeMethod("forward", ExecuTorchJni::forward), makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer), }); } }; } // namespace executorch::extension #ifdef EXECUTORCH_BUILD_LLAMA_JNI extern void register_natives_for_llama(); #else // No op if we don't build llama void register_natives_for_llama() {} #endif JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize(vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); register_natives_for_llama(); }); }