Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Oct 24, 2024
1 parent bb00fe1 commit c9e8174
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,15 @@ def test_forward_method(tmp_path):
logits, loss = llm(inputs, target_ids=inputs)
assert logits.shape == torch.Size([6, 128, 50304])
assert isinstance(loss.item(), float)


def test_precision_selection(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
init="pretrained"
)

llm.distribute(precision="16-true")
assert llm.model._forward_module.lm_head.weight.dtype == torch.float16, \
f"Expected float16, but got {llm.model._forward_module.lm_head.weight.dtype}"

0 comments on commit c9e8174

Please sign in to comment.