[llava] Use huggingface LLaVA instead of depending on third-party/LLaVa
Differential Revision: D61200610

Pull Request resolved: #4687
larryliu0820 authored Aug 14, 2024
1 parent 49c6a10 commit 5d151d0
Showing 7 changed files with 176 additions and 208 deletions.
20 changes: 3 additions & 17 deletions .github/workflows/pull.yml
@@ -205,27 +205,13 @@ jobs:
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
# install pybind
bash install_requirements.sh --pybind xnnpack
# install Llava requirements
bash examples/models/llama2/install_requirements.sh
bash examples/models/llava/install_requirements.sh
-# run export_llava.sh
-python examples/models/llava/export_llava.py --use-sdpa-with-kv-cache --pte-name llava_custom_sdpa.pte
-# verify file exists
-if [ ! -f "llava_custom_sdpa.pte" ]; then
-  echo "llava_custom_sdpa.pte not found!"
-  exit 1
-fi
-python examples/models/llava/export_llava.py --no-use-sdpa-with-kv-cache --pte-name llava.pte
-# verify file exists
-if [ ! -f "llava.pte" ]; then
-  echo "llava.pte not found!"
-  exit 1
-fi
+# run python unittest
+python -m unittest examples.models.llava.test.test_llava
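The workflow now validates the Llava export through a Python unittest instead of the shell-based export-and-verify steps. The actual test module is not part of this diff; a minimal sketch of what such a test might look like (assuming export_all is importable from export_llava as refactored below, and that the returned program manager exposes its serialized bytes as .buffer):

import unittest

from executorch.examples.models.llava.export_llava import export_all
from executorch.examples.models.llava.model import LlavaModel


class TestLlavaExport(unittest.TestCase):
    def test_export_all(self):
        # Build the eager model and run the full export pipeline in-process.
        llava_model = LlavaModel(use_sdpa_with_kv_cache_op=True)
        executorch_program = export_all(llava_model)
        # A non-empty serialized ExecuTorch program should come out.
        self.assertGreater(len(executorch_program.buffer), 0)


if __name__ == "__main__":
    unittest.main()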
3 changes: 0 additions & 3 deletions .gitmodules
@@ -28,9 +28,6 @@
[submodule "backends/xnnpack/third-party/pthreadpool"]
path = backends/xnnpack/third-party/pthreadpool
url = https://github.com/Maratyszcza/pthreadpool.git
[submodule "examples/third-party/LLaVA"]
path = examples/third-party/LLaVA
url = https://github.com/haotian-liu/LLaVA.git
[submodule "examples/third-party/fbjni"]
path = examples/third-party/fbjni
url = https://github.com/facebookincubator/fbjni.git
56 changes: 27 additions & 29 deletions examples/models/llava/export_llava.py
@@ -24,11 +24,11 @@
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
+from executorch.examples.models.llava.model import LlavaModel
from executorch.exir import EdgeCompileConfig
from executorch.exir.program._program import _to_edge_transform_and_lower

from executorch.extension.llm.export.builder import DType, LLMEdgeManager
-from model import LlavaModel
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
@@ -85,7 +85,7 @@ def forward(self, input_pos, embeddings):
["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
)
quant_transform = get_quant_weight_transform(args, dtype_override, False)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
_, quantizers, _ = get_quantizer_and_quant_params(args)
source_transforms = []
if llava.use_sdpa_with_kv_cache_op:
source_transforms.append(replace_sdpa_with_custom_op)
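For orientation: quantizers feeds pt2e quantization during the text-model export, while the 8da4w weight quantization above is applied as a source transform. When a pt2e path is requested, the list would typically hold an XNNPACKQuantizer; a rough sketch of assembling one (illustrative only, not necessarily the config get_quantizer_and_quant_params builds here):

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

# Symmetric, dynamic-activation config for XNNPACK-lowered linear layers.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_dynamic=True))
quantizers = [quantizer]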
@@ -149,15 +149,7 @@ def forward(self, images):


def export_token_embedding(llava, prompt):
-    embed = torch.nn.Embedding(
-        llava.model_.config.vocab_size,
-        llava.model_.config.hidden_size,
-        llava.model_.config.pad_token_id,
-    )
-    embed.load_state_dict(
-        llava.model_.get_model().embed_tokens.state_dict(), strict=True, assign=True
-    )
-    embed = embed.to(torch.float32)
+    embed = llava.embed_tokens
    token_dim_1 = Dim("token_dim_1", min=2, max=3518)
    dynamic_shapes = [{1: token_dim_1}]
    with torch.no_grad():
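The Dim above marks dimension 1 of the prompt (the sequence length) as dynamic, so a single exported program accepts prompts of 2 to 3518 tokens. A self-contained sketch of the same export pattern with a stand-in embedding (the sizes are placeholders, not Llava's real vocab or hidden dimensions):

import torch
from torch.export import Dim, export

embed = torch.nn.Embedding(1000, 64)       # stand-in for llava.embed_tokens
prompt = torch.randint(0, 1000, (1, 17))   # (batch, seq_len) token ids

token_dim_1 = Dim("token_dim_1", min=2, max=3518)
# Dimension 1 of the single positional argument is allowed to vary.
token_embedding_ep = export(embed, (prompt,), dynamic_shapes=[{1: token_dim_1}])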
@@ -167,24 +159,7 @@ def export_token_embedding(llava, prompt):
    return token_embedding_ep


-def main():
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--use-sdpa-with-kv-cache",
-        default=True,
-        action=BooleanOptionalAction,
-        help="Use sdpa_with_kv_cache custom op in LLava text model.",
-    )
-    parser.add_argument(
-        "--pte-name",
-        default="llava_combined_xnnpack.pte",
-        help="Name of the exported ExecuTorch program.",
-    )
-    args = parser.parse_args()
-    logging.info(
-        f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}"
-    )
-    llava_model = LlavaModel(use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache)
+def export_all(llava_model: LlavaModel):
    llava = llava_model.get_eager_model()

    (
@@ -226,6 +201,29 @@ def main():
    )

    executorch_program = lowered_and_edge.to_executorch()
+    return executorch_program
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--use-sdpa-with-kv-cache",
+        default=True,
+        action=BooleanOptionalAction,
+        help="Use sdpa_with_kv_cache custom op in LLava text model.",
+    )
+    parser.add_argument(
+        "--pte-name",
+        default="llava_combined_xnnpack.pte",
+        help="Name of the exported ExecuTorch program.",
+    )
+    args = parser.parse_args()
+    logging.info(
+        f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}"
+    )
+    llava_model = LlavaModel(use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache)
+
+    executorch_program = export_all(llava_model)

    with open(args.pte_name, "wb") as f:
        executorch_program.write_to_file(f)
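Factoring the pipeline into export_all is what lets the new unittest build the program in-process; the two .pte variants the workflow previously produced via separate CLI runs can likewise be generated from one script (a sketch, under the same import assumptions as the test above):

from executorch.examples.models.llava.export_llava import export_all
from executorch.examples.models.llava.model import LlavaModel

for use_sdpa, pte_name in [(True, "llava_custom_sdpa.pte"), (False, "llava.pte")]:
    model = LlavaModel(use_sdpa_with_kv_cache_op=use_sdpa)
    with open(pte_name, "wb") as f:
        export_all(model).write_to_file(f)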
34 changes: 1 addition & 33 deletions examples/models/llava/install_requirements.sh
@@ -6,39 +6,7 @@
# LICENSE file in the root directory of this source tree.

set -x
-OS=$(uname)

-# install llava from the submodule. We can't do pip install llava because it is packaged incorrectly.
-if [[ $OS != "Darwin" ]];
-then
-  #This doesn't work for macos, on python 3.12, because torch 2.1.2 is missing.
-  pip install --force-reinstall -e examples/third-party/LLaVA
-else
-  # manually install dependencies
-  pip install tokenizers==0.15.1 sentencepiece==0.1.99 \
-    shortuuid accelerate==0.21.0 peft \
-    pydantic markdown2[all] scikit-learn==1.2.2 \
-    requests httpx==0.24.0 uvicorn fastapi \
-    einops==0.6.1 einops-exts==0.0.4 timm==0.6.13
-
-  pip install --force-reinstall -e examples/third-party/LLaVA --no-deps
-fi
-
-# not included in the pip install package, but needed in llava
-pip install protobuf
-
-# bitsandbytes depends on numpy 1.x, which is not compatible with numpy 2.x.
-# Reinstall bitsandbytes to make it compatible.
-pip install bitsandbytes -I
-
-# The deps of llava can have different versions than deps of ExecuTorch.
-# For example, torch version required from llava is older than ExecuTorch.
-# To make both work, recover ExecuTorch's original dependencies by rerunning
-# the install_requirements.sh. Notice this won't install executorch.
-bash -x ./install_requirements.sh --pybind xnnpack
-
-# Newer transformer (4.38) will give TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'
-pip install timm==0.6.13
-pip install transformers==4.37.2
+pip install transformers

pip list
[Diffs for the remaining 3 changed files are not shown.]
