Commit

lint: fix more fmt errors
Signed-off-by: Angel Luu <[email protected]>
aluu317 committed Sep 17, 2024
1 parent fa97871 commit 4c9bb95
Showing 1 changed file with 16 additions and 8 deletions.
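
The reflowed call sites in the diff below are consistent with black's output at its default 88-character line length. As an illustration (the commit does not name the formatter, so black is an assumption here), this sketch reproduces the reflow of the post-processing call via black's Python API:

    # Third Party
    import black

    # The over-long call as it stood before this commit, at the same
    # indentation depth it has inside the test function (over the
    # 88-character limit).
    SRC = (
        "def test():\n"
        "    with tempfile.TemporaryDirectory() as tempdir:\n"
        "        post_process_vLLM_adapters_new_tokens("
        "DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir)\n"
    )

    # format_str splits the call across lines once it exceeds the default
    # limit, yielding the multi-line form shown in the diff below.
    print(black.format_str(SRC, mode=black.Mode()))
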
24 changes: 16 additions & 8 deletions tests/utils/test_merge_model_utils.py
@@ -15,19 +15,23 @@
 """Unit Tests for SFT Trainer's merge_model_utils functions
 """
 
+# Standard
+import os
+import tempfile
+
 # Third Party
 from safetensors import safe_open
-import tempfile
-import torch
-import os
 import pytest
+import torch
 
 # Local
 from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
-DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS = os.path.join(dir_path,
-                            "../artifacts/tuned_llama_with_added_tokens")
+DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS = os.path.join(
+    dir_path, "../artifacts/tuned_llama_with_added_tokens"
+)
 
+
 @pytest.mark.skipif(
     not (torch.cuda.is_available()),
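
An aside on the regrouped imports in the hunk above: the "# Standard" / "# Third Party" / "# Local" comments are the style isort emits with its import-heading settings. The configuration below is a hypothetical sketch that reproduces the grouping and ordering seen in this diff; the repository's actual lint settings are not shown on this page.

    # Third Party
    import isort

    # Hypothetical settings: every value here is inferred from the shape
    # of the diff, not taken from the repo's lint config.
    config = isort.Config(
        profile="black",
        from_first=True,  # from-imports before plain imports, as in the diff
        import_heading_stdlib="Standard",
        import_heading_thirdparty="Third Party",
        import_heading_firstparty="Local",
        known_first_party=frozenset({"tuning"}),
    )

    messy = (
        "from safetensors import safe_open\n"
        "import tempfile\n"
        "import torch\n"
        "import os\n"
        "import pytest\n"
        "from tuning.utils.merge_model_utils import"
        " post_process_vLLM_adapters_new_tokens\n"
    )

    # Prints the import block regrouped under the three headings,
    # matching the post-commit state of the file.
    print(isort.code(messy, config=config))
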
@@ -40,16 +44,20 @@ def test_post_process_vLLM_adapters_new_tokens():
     """
     # first, double check dummy tuned llama has a lm_head.weight
     found_lm_head = False
-    with safe_open(os.path.join(DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, "adapter_model.safetensors"),
-                   framework="pt") as f:
+    with safe_open(
+        os.path.join(DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, "adapter_model.safetensors"),
+        framework="pt",
+    ) as f:
         for k in f.keys():
             if "lm_head.weight" in k:
                 found_lm_head = True
     assert found_lm_head
 
     # do the post processing
     with tempfile.TemporaryDirectory() as tempdir:
-        post_process_vLLM_adapters_new_tokens(DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir)
+        post_process_vLLM_adapters_new_tokens(
+            DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir
+        )
 
         # check that new_embeddings.safetensors exist
         new_embeddings = os.path.join(tempdir, "new_embeddings.safetensors")
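
One detail worth noting in the hunk above: safe_open reads only the checkpoint's header, so the lm_head.weight check inspects tensor names without loading any weights into memory. Below is a minimal standalone sketch of that key-scan pattern; the path in the usage comment is a placeholder, not a file from this repository.

    # Third Party
    from safetensors import safe_open


    def contains_tensor(checkpoint_path: str, name_fragment: str) -> bool:
        """Return True if any tensor name in a .safetensors checkpoint
        contains name_fragment (e.g. "lm_head.weight"), without loading
        tensor data into memory."""
        with safe_open(checkpoint_path, framework="pt") as f:
            return any(name_fragment in key for key in f.keys())


    # Hypothetical usage; "adapter_model.safetensors" is a placeholder path:
    # contains_tensor("adapter_model.safetensors", "lm_head.weight")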
