diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 1533bec3fd2..07243911fc1 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -26,7 +26,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main strategy: matrix: - tokenizer: [bpe, tiktoken] + tokenizer: [bpe] with: runner: linux.2xlarge docker-image: executorch-ubuntu-22.04-clang12-android diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index 5bba039a311..ac2ac0156b1 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -60,12 +60,27 @@ build_android_native_library() { cmake --build "${CMAKE_OUT}"/examples/models/llama2 -j "${CMAKE_JOBS}" --config Release + cmake examples/models/llava \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI="$ANDROID_ABI" \ + -DANDROID_PLATFORM=android-23 \ + -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ + -DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -B"${CMAKE_OUT}"/examples/models/llava + + cmake --build "${CMAKE_OUT}"/examples/models/llava -j "${CMAKE_JOBS}" --config Release + cmake extension/android \ -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ -DANDROID_ABI="${ANDROID_ABI}" \ -DANDROID_PLATFORM=android-23 \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ + -DEXECUTORCH_BUILD_MULTIMODAL_JNI=ON \ -DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/extension/android @@ -89,6 +104,7 @@ build_aar() { # Zip all necessary files into the AAR file zip -r executorch.aar libs jni/*/libexecutorch.so AndroidManifest.xml zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so AndroidManifest.xml + zip -r executorch-multimodal.aar libs 
jni/*/libexecutorch_multimodal_jni.so AndroidManifest.xml popd } diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index aa5a40a875e..5cb8a1960e3 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -139,3 +139,66 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) target_link_libraries(executorch_llama_jni re2::re2) endif() endif() + +if(EXECUTORCH_BUILD_MULTIMODAL_JNI) + set(MULTIMODAL_RUNNER_PATH + ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llava/runner/libmultimodal_runner.a + ) + add_library(multimodal_runner STATIC IMPORTED) + set_property( + TARGET multimodal_runner PROPERTY IMPORTED_LOCATION ${MULTIMODAL_RUNNER_PATH} + ) + + target_link_options_shared_lib(quantized_ops_lib) + + if(TARGET pthreadpool) + set(MULTIMODAL_JNI_SRCS jni/jni_layer_multimodal.cpp + ../../backends/xnnpack/threadpool/cpuinfo_utils.cpp + ) + else() + set(MULTIMODAL_JNI_SRCS jni/jni_layer_multimodal.cpp) + endif() + add_library(executorch_multimodal_jni SHARED ${MULTIMODAL_JNI_SRCS}) + if(TARGET pthreadpool) + target_compile_definitions(executorch_multimodal_jni PRIVATE ET_USE_THREADPOOL=1) + target_include_directories( + executorch_multimodal_jni + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/cpuinfo/include + ) + target_include_directories( + executorch_multimodal_jni + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/pthreadpool/include + ) + endif() + target_include_directories( + executorch_multimodal_jni PRIVATE ${_common_include_directories} + ) + target_link_libraries( + executorch_multimodal_jni + ${link_libraries} + multimodal_runner + custom_ops + cpublas + eigen_blas + quantized_kernels + quantized_ops_lib + ) + target_compile_options(executorch_multimodal_jni PUBLIC ${_common_compile_options}) + if(EXECUTORCH_USE_TIKTOKEN) + set(ABSL_ENABLE_INSTALL ON) + set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE}) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + add_subdirectory( + 
${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/abseil-cpp + ${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp + ) + add_subdirectory( + ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/re2 + ${CMAKE_CURRENT_BINARY_DIR}/re2 + ) + set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) + target_link_libraries(executorch_multimodal_jni re2::re2) + endif() +endif() diff --git a/extension/android/jni/jni_layer_multimodal.cpp b/extension/android/jni/jni_layer_multimodal.cpp new file mode 100644 index 00000000000..cc7df240754 --- /dev/null +++ b/extension/android/jni/jni_layer_multimodal.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <cassert> +#include <chrono> +#include <cstdint> +#include <cstdio> +#include <iostream> +#include <memory> +#include <string> +#include <vector> + +#include <executorch/examples/models/llava/runner/multimodal_runner.h> +#include <executorch/runtime/platform/log.h> +#include <executorch/runtime/platform/platform.h> +#include <executorch/runtime/platform/runtime.h> + +#if defined(ET_USE_THREADPOOL) +#include <executorch/backends/xnnpack/threadpool/cpuinfo_utils.h> +#include <executorch/backends/xnnpack/threadpool/threadpool.h> +#endif + +#include <fbjni/ByteBuffer.h> +#include <fbjni/fbjni.h> + +#ifdef __ANDROID__ +#include <android/log.h> + +// For Android, write to logcat +void et_pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) { + int android_log_level = ANDROID_LOG_UNKNOWN; + if (level == 'D') { + android_log_level = ANDROID_LOG_DEBUG; + } else if (level == 'I') { + android_log_level = ANDROID_LOG_INFO; + } else if (level == 'E') { + android_log_level = ANDROID_LOG_ERROR; + } else if (level == 'F') { + android_log_level = ANDROID_LOG_FATAL; + } + + __android_log_print(android_log_level, "MULTIMODAL", "%s", message); +} +#endif + +using namespace torch::executor; + +namespace executorch_jni { + +class ExecuTorchMultiModalCallbackJni + : public facebook::jni::JavaClass<ExecuTorchMultiModalCallbackJni> { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/MultiModalCallback;"; + + void 
onResult(std::string result) const { + static auto cls = ExecuTorchMultiModalCallbackJni::javaClassStatic(); + static const auto method = + cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult"); + facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result); + method(self(), s); + } + + void onStats(const MultiModalRunner::Stats& result) const { + static auto cls = ExecuTorchMultiModalCallbackJni::javaClassStatic(); + static const auto method = cls->getMethod<void(jfloat)>("onStats"); + double eval_time = + (double)(result.inference_end_ms - result.prompt_eval_end_ms); + + float tps = result.num_generated_tokens / eval_time * + result.SCALING_FACTOR_UNITS_PER_SECOND; + + method(self(), tps); + } +}; + +class ExecuTorchMultiModalJni + : public facebook::jni::HybridClass<ExecuTorchMultiModalJni> { + private: + friend HybridBase; + std::unique_ptr<MultiModalRunner> runner_; + + public: + constexpr static auto kJavaDescriptor = + "Lorg/pytorch/executorch/MultiModalModule;"; + + static facebook::jni::local_ref<jhybriddata> initHybrid( + facebook::jni::alias_ref<jclass>, + facebook::jni::alias_ref<jstring> model_path, + facebook::jni::alias_ref<jstring> tokenizer_path, + jfloat temperature) { + return makeCxxInstance(model_path, tokenizer_path, temperature); + } + + ExecuTorchMultiModalJni( + facebook::jni::alias_ref<jstring> model_path, + facebook::jni::alias_ref<jstring> tokenizer_path, + jfloat temperature) { +#if defined(ET_USE_THREADPOOL) + // Reserve 1 thread for the main thread. 
+ uint32_t num_performant_cores = + torch::executorch::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); + torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool( + num_performant_cores); + } +#endif + + runner_ = std::make_unique<MultiModalRunner>( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + temperature); + } + + jint generate( + facebook::jni::alias_ref<jintArray> image, + jint width, + jint height, + jint channels, + facebook::jni::alias_ref<jstring> prompt, + jint startPos, + facebook::jni::alias_ref<ExecuTorchMultiModalCallbackJni> callback) { + auto image_size = image->size(); + std::vector<Image> images; + if (image_size != 0) { + std::vector<jint> image_data_jint(image_size); + std::vector<uint8_t> image_data(image_size); + image->getRegion(0, image_size, image_data_jint.data()); + for (int i = 0; i < image_size; i++) { + image_data[i] = image_data_jint[i]; + } + Image image_runner{image_data, width, height, channels}; + images.push_back(image_runner); + } + runner_->generate( + images, + prompt->toStdString(), + 1024, + [callback](std::string result) { callback->onResult(result); }, + [callback](const MultiModalRunner::Stats& result) { + callback->onStats(result); + }); + return 0; + } + + void stop() { + runner_->stop(); + } + + jint load() { + return static_cast<jint>(runner_->load()); + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ExecuTorchMultiModalJni::initHybrid), + makeNativeMethod("generate", ExecuTorchMultiModalJni::generate), + makeNativeMethod("stop", ExecuTorchMultiModalJni::stop), + makeNativeMethod("load", ExecuTorchMultiModalJni::load), + }); + } +}; + +} // namespace executorch_jni + +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { + return facebook::jni::initialize( + vm, [] { executorch_jni::ExecuTorchMultiModalJni::registerNatives(); }); +} diff --git 
a/extension/android/src/main/java/org/pytorch/executorch/MultiModalCallback.java b/extension/android/src/main/java/org/pytorch/executorch/MultiModalCallback.java new file mode 100644 index 00000000000..95ba762b2de --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/MultiModalCallback.java @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.annotations.DoNotStrip; + +public interface MultiModalCallback { + /** + * Called when a new result is available from JNI. Users will keep getting onResult() invocations + * until generate() finishes. + * + * @param result Last generated token + */ + @DoNotStrip + public void onResult(String result); + + /** + * Called when the statistics for the generate() is available. + * + * @param tps Tokens/second for generated tokens. + */ + @DoNotStrip + public void onStats(float tps); +} diff --git a/extension/android/src/main/java/org/pytorch/executorch/MultiModalModule.java b/extension/android/src/main/java/org/pytorch/executorch/MultiModalModule.java new file mode 100644 index 00000000000..69f3f6b2291 --- /dev/null +++ b/extension/android/src/main/java/org/pytorch/executorch/MultiModalModule.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package org.pytorch.executorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; + +public class MultiModalModule { + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + NativeLoader.loadLibrary("executorch_multimodal_jni"); + } + + private final HybridData mHybridData; + + @DoNotStrip + private static native HybridData initHybrid( + String modulePath, String tokenizerPath, float temperature); + + /** Constructs a MultiModal Module for a model with given path, tokenizer, and temperature. */ + public MultiModalModule(String modulePath, String tokenizerPath, float temperature) { + mHybridData = initHybrid(modulePath, tokenizerPath, temperature); + } + + public void resetNative() { + mHybridData.resetNative(); + } + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param MultiModalCallback callback object to receive results. + */ + @DoNotStrip + public native int generate(int[] image, int width, int height, int channels, String prompt, int startPos, MultiModalCallback MultiModalCallback); + + /** Stop current generate() before it finishes. */ + @DoNotStrip + public native void stop(); + + /** Force loading the module. Otherwise the model is loaded during first generate(). */ + @DoNotStrip + public native int load(); +}