Forward arguments from TGI launcher to the model (#28)
* Include revision

* Expose max_batch_size as an env var for the TGI entrypoint

* Remove Intellij files from git

* Remove unused variable in entrypoint

* again

* Rename TGI_MAX_INPUT_LENGTH to TGI_MAX_INPUT_TOKENS to stay in tokens

* Allow using a specific TGI commit

* Delete comments

* Make it possible to install a specific TGI commit in tgi_test as well

* Oops missing one file

* Leverage forwarded variables from the launcher to allocate the model

* Fix invalid variable name

* Add missing find-links argument to make the dependent tests run

* Update tests with new args

* Revert using git and use curl + github archive

* Let's define max-batch-prefill-tokens too

* Let's map the model_id to the value provided by

* Remove overriding TGI entrypoint
mfuntowicz authored Apr 26, 2024
1 parent e663b13 commit c7fe483
Showing 14 changed files with 84 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_code_quality.yml
@@ -44,7 +44,7 @@ jobs:
run: |
source venv/bin/activate
pip install --upgrade pip
pip install .[quality]
pip install .[quality] -f https://storage.googleapis.com/libtpu-releases/index.html
- name: Check style with ruff
run: |
source venv/bin/activate
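The extra -f/--find-links flag (added to this workflow and to the two doc workflows below) lets pip resolve the libtpu wheels that torch-xla requires on TPU hosts from Google's release index. A local setup would use the same flag; a minimal sketch, assuming a checkout of this repository:

    pip install --upgrade pip
    pip install ".[quality]" -f https://storage.googleapis.com/libtpu-releases/index.html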
2 changes: 1 addition & 1 deletion .github/workflows/doc-build.yml
@@ -47,7 +47,7 @@ jobs:
- name: Setup environment
run: |
pip install -U pip
pip install ".[quality]"
pip install ".[quality]" -f https://storage.googleapis.com/libtpu-releases/index.html
- name: Make documentation
shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/doc-pr-build.yml
@@ -32,7 +32,7 @@ jobs:
- name: Setup environment
run: |
pip install -U pip
pip install ".[quality]"
pip install ".[quality]" -f https://storage.googleapis.com/libtpu-releases/index.html
- name: Make documentation
shell: bash
3 changes: 2 additions & 1 deletion .gitignore
@@ -132,4 +132,5 @@ dmypy.json
# Models
*.pt

.vscode
.vscode
.idea/
2 changes: 1 addition & 1 deletion Makefile
@@ -19,7 +19,7 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))

.PHONY: build_dist style style_check clean

TGI_VERSION ?= 2.0.0
TGI_VERSION ?= 5bc3d65dd32ba1f979540caeccbf3dd8798dd9df

rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*))))

1 change: 1 addition & 0 deletions optimum/tpu/modeling.py
@@ -54,6 +54,7 @@ def from_pretrained(
cls = config_name_to_class(pretrained_model_name_or_path)
model = cls.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.to(device)

# Update config with specific data
if task is not None or getattr(model.config, "task", None) is None:
model.config.task = task
12 changes: 3 additions & 9 deletions text-generation-inference/Dockerfile
@@ -3,7 +3,7 @@ FROM alpine AS tgi
ARG TGI_VERSION
RUN test -n ${TGI_VERSION:?}
RUN mkdir -p /tgi
ADD https://github.com/huggingface/text-generation-inference/archive/refs/tags/v${TGI_VERSION}.tar.gz /tgi/sources.tar.gz
ADD https://github.com/huggingface/text-generation-inference/archive/${TGI_VERSION}.tar.gz /tgi/sources.tar.gz
RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
@@ -120,15 +120,9 @@ COPY --from=pyserver /pyserver/build/dist dist
RUN pip install dist/text_generation_server*.tar.gz

# TPU compatible image
FROM tpu_base as tpu_entrypoint
FROM tpu_base

COPY text-generation-inference/entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM tpu_base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
ENTRYPOINT ["./entrypoint.sh"]
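Since the archive URL no longer points at a release tag, TGI_VERSION can be either a tag or a commit SHA when building the image. A minimal sketch of a build against the pinned commit, assuming the repository root as build context; the image tag is illustrative:

    docker build \
      -f text-generation-inference/Dockerfile \
      --build-arg TGI_VERSION=5bc3d65dd32ba1f979540caeccbf3dd8798dd9df \
      -t tpu-tgi:dev .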
43 changes: 35 additions & 8 deletions text-generation-inference/entrypoint.sh
@@ -1,17 +1,44 @@
#!/bin/bash

if [[ -z "${HF_MODEL_ID}" ]]; then
echo "HF_MODEL_ID must be set"
# Hugging Face Hub related
if [[ -z "${MODEL_ID}" ]]; then
echo "MODEL_ID must be set"
exit 1
fi
export MODEL_ID="${HF_MODEL_ID}"
export MODEL_ID="${MODEL_ID}"

if [[ -n "${HF_MODEL_REVISION}" ]]; then
export REVISION="${HF_MODEL_REVISION}"
# TGI related
if [[ -n "${TGI_MAX_CONCURRENT_REQUESTS}" ]]; then
export TGI_MAX_CONCURRENT_REQUESTS="${TGI_MAX_CONCURRENT_REQUESTS}"
else
export TGI_MAX_CONCURRENT_REQUESTS=4
fi

if [[ -n "${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then
export TRUST_REMOTE_CODE="${HF_MODEL_TRUST_REMOTE_CODE}"
if [[ -n "${TGI_MAX_BATCH_SIZE}" ]]; then
export TGI_MAX_BATCH_SIZE="${TGI_MAX_BATCH_SIZE}"
else
export TGI_MAX_BATCH_SIZE=1
fi

text-generation-launcher --port 8080
if [[ -n "${TGI_MAX_INPUT_TOKENS}" ]]; then
export TGI_MAX_INPUT_TOKENS="${TGI_MAX_INPUT_TOKENS}"
else
export TGI_MAX_INPUT_TOKENS=128
fi

if [[ -n "${TGI_MAX_TOTAL_TOKENS}" ]]; then
export TGI_MAX_TOTAL_TOKENS="${TGI_MAX_TOTAL_TOKENS}"
else
export TGI_MAX_TOTAL_TOKENS=256
fi

TGI_MAX_BATCH_PREFILL_TOKENS=$(( TGI_MAX_BATCH_SIZE*TGI_MAX_INPUT_TOKENS ))

text-generation-launcher --port 8080 \
--max-concurrent-requests ${TGI_MAX_CONCURRENT_REQUESTS} \
--max-batch-size ${TGI_MAX_BATCH_SIZE} \
--max-batch-prefill-tokens ${TGI_MAX_BATCH_PREFILL_TOKENS} \
--max-input-tokens ${TGI_MAX_INPUT_TOKENS} \
--max-total-tokens ${TGI_MAX_TOTAL_TOKENS} \
--model-id ${MODEL_ID}
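With this change every launcher flag is derived from an environment variable, and --max-batch-prefill-tokens is computed as TGI_MAX_BATCH_SIZE * TGI_MAX_INPUT_TOKENS (with the defaults above, 1 * 128 = 128). A minimal sketch of running an image built from this Dockerfile; the image name and model are illustrative, and TPU device/runtime flags are omitted:

    docker run -p 8080:8080 \
      -e MODEL_ID=google/gemma-2b \
      -e TGI_MAX_BATCH_SIZE=1 \
      -e TGI_MAX_INPUT_TOKENS=128 \
      -e TGI_MAX_TOTAL_TOKENS=256 \
      tpu-tgi:dev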

6 changes: 3 additions & 3 deletions text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@
pkg_name := text_generation_server
BUILDDIR ?= $(CURDIR)/build
VERSION ?= 0.0.1
TGI_VERSION ?= 1.4.2
TGI_VERSION ?= v2.0.1
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
pkg_dir := $(BUILDDIR)/$(pkg_name)
@@ -39,7 +39,7 @@ endif

$(BUILDDIR)/tgi/proto/%.proto:
install -d $(BUILDDIR)/tgi
curl -L https://github.com/huggingface/text-generation-inference/archive/refs/tags/v${TGI_VERSION}.tar.gz --output $(BUILDDIR)/tgi/sources.tar.gz
curl -L https://github.com/huggingface/text-generation-inference/archive/${TGI_VERSION}.tar.gz --output $(BUILDDIR)/tgi/sources.tar.gz
tar -C $(BUILDDIR)/tgi -xf $(BUILDDIR)/tgi/sources.tar.gz --strip-components=1

# Three python files are generated for each protobuf
@@ -57,4 +57,4 @@ $(pkg_pb_dir)/%_pb2.py $(pkg_pb_dir)/%_pb2.pyi $(pkg_pb_dir)/%_pb2_grpc.py: $(PR
sed -i -e 's/^\(import.*pb2\)/from . \1/g' $(pkg_pb_dir)/$*_pb2_grpc.py

gen-server: $(BUILDDIR)/pyproject.toml $(deployed_sources) $(generated_sources)
python -m build $(BUILDDIR)
python -m build $(BUILDDIR)
14 changes: 13 additions & 1 deletion text-generation-inference/server/text_generation_server/cli.py
@@ -1,3 +1,4 @@
import os
import sys
from typing import Optional

@@ -58,8 +59,19 @@ def serve(
from optimum.tpu.model import fetch_model
from .server import serve

# Read environment variables forwarded by the launcher
max_batch_size = int(os.environ.get("MAX_BATCH_SIZE", "1"))
max_total_tokens = int(os.environ.get("MAX_TOTAL_TOKENS", "64"))

# Start the server
model_path = fetch_model(model_id, revision)
serve(model_path, uds_path)
serve(
model_path,
revision=revision,
max_batch_size=max_batch_size,
max_sequence_length=max_total_tokens,
uds_path=uds_path
)


@app.command()
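The serve command now reads MAX_BATCH_SIZE and MAX_TOTAL_TOKENS from its environment; these are the values the TGI launcher is expected to forward to the shard process from its --max-batch-size and --max-total-tokens flags, i.e. the ones set by the entrypoint above. A sketch of driving the Python server directly for debugging, assuming TGI's text-generation-server CLI entry point; the model id is illustrative:

    MAX_BATCH_SIZE=1 MAX_TOTAL_TOKENS=256 text-generation-server serve google/gemma-2b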
@@ -621,6 +621,9 @@ def _clear(self, request_ids: List):
def from_pretrained(
cls,
model_path: str,
revision: str,
max_batch_size: int,
max_sequence_length: int
):
"""Instantiate a TpuGenerator.
@@ -633,7 +636,12 @@ """
"""
logger.info("Loading model (this can take a few minutes).")
start = time.time()
model = AutoModelForCausalLM.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path,
revision=revision,
batch_size=max_batch_size,
sequence_length=max_sequence_length
)
end = time.time()
logger.info(f"Model successfully loaded in {end - start:.2f} s.")
tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -50,6 +50,9 @@ async def Decode(self, request, context):

def serve(
model_path: str,
revision: str,
max_batch_size: int,
max_sequence_length: int,
uds_path: Path,
):
async def serve_inner(model_path: str):
@@ -58,7 +61,12 @@ async def serve_inner(model_path: str):
server_urls = [local_url]

try:
generator = TpuGenerator.from_pretrained(model_path)
generator = TpuGenerator.from_pretrained(
model_path,
revision=revision,
max_batch_size=max_batch_size,
max_sequence_length=max_sequence_length
)
except Exception:
logger.exception("Error when initializing model")
raise
2 changes: 1 addition & 1 deletion text-generation-inference/tests/test_gemma.py
@@ -57,7 +57,7 @@ def test_decode_single(model_path):
max_new_tokens = 20
generated_text = "\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never"

generator = TpuGenerator.from_pretrained(model_path)
generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
generations, next_batch = generator.prefill(batch)
8 changes: 4 additions & 4 deletions text-generation-inference/tests/test_gpt2.py
@@ -24,7 +24,7 @@ def model_path():


def test_info(model_path):
generator = TpuGenerator.from_pretrained(model_path)
generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
info = generator.info
assert info.requires_padding is True
assert info.device_type == "xla"
@@ -81,7 +81,7 @@ def create_request(
)
@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
generator = TpuGenerator.from_pretrained(model_path)
generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
requests = []
max_new_tokens = 20
for i in range(batch_size):
@@ -120,7 +120,7 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
ids=["greedy", "sample"],
)
def test_decode_single(input_text, max_new_tokens, generated_text, do_sample, model_path):
generator = TpuGenerator.from_pretrained(model_path)
generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample)
batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
generations, next_batch = generator.prefill(batch)
@@ -140,7 +140,7 @@ def test_decode_single(input_text, max_new_tokens, generated_text, do_sample, mo


def test_decode_multiple(model_path):
generator = TpuGenerator.from_pretrained(model_path)
generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
input_text = "Once upon a time"
max_new_tokens = 20
# Prefill a single request, remembering the generated token
