Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Support loading PEFT (LoRA) models
Browse files Browse the repository at this point in the history
  • Loading branch information
idoru committed Jun 4, 2023
1 parent a368310 commit 008f8f0
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions basaran/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def is_true(value):
PORT = int(os.getenv("PORT", "80"))

# Model-related arguments:
MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
MODEL_REVISION = os.getenv("MODEL_REVISION", "")
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
Expand Down
2 changes: 2 additions & 0 deletions basaran/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from . import MODEL_LOAD_IN_4BIT
from . import MODEL_4BIT_QUANT_TYPE
from . import MODEL_4BIT_DOUBLE_QUANT
from . import MODEL_PEFT
from . import MODEL_LOCAL_FILES_ONLY
from . import MODEL_TRUST_REMOTE_CODE
from . import MODEL_HALF_PRECISION
Expand All @@ -44,6 +45,7 @@
name_or_path=MODEL,
revision=MODEL_REVISION,
cache_dir=MODEL_CACHE_DIR,
is_peft=MODEL_PEFT,
load_in_8bit=MODEL_LOAD_IN_8BIT,
load_in_4bit=MODEL_LOAD_IN_4BIT,
quant_type=MODEL_4BIT_QUANT_TYPE,
Expand Down
9 changes: 9 additions & 0 deletions basaran/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
TopPLogitsWarper,
BitsAndBytesConfig
)
from peft import (
PeftConfig,
PeftModel
)

from .choice import map_choice
from .tokenizer import StreamTokenizer
Expand Down Expand Up @@ -302,6 +306,7 @@ def load_model(
name_or_path,
revision=None,
cache_dir=None,
is_peft=False,
load_in_8bit=False,
load_in_4bit=False,
quant_type="fp4",
Expand Down Expand Up @@ -346,6 +351,10 @@ def load_model(
if half_precision or load_in_8bit or load_in_4bit:
kwargs["torch_dtype"] = torch.float16

if is_peft:
peft_config = PeftConfig.from_pretrained(name_or_path)
name_or_path = peft_config.base_model_name_or_path

# Support both decoder-only and encoder-decoder models.
try:
model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ safetensors~=0.3.1
torch>=1.12.1
transformers[sentencepiece]~=4.29.2
waitress~=2.1.2
peft~=0.3.0

0 comments on commit 008f8f0

Please sign in to comment.