From c9e8174162243104f03cb09f4f596a3f1a07a6e6 Mon Sep 17 00:00:00 2001 From: rasbt Date: Thu, 24 Oct 2024 07:39:21 -0500 Subject: [PATCH] add unit test --- tests/test_api.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_api.py b/tests/test_api.py index 0064ae5400..9b35fe57fb 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -399,3 +399,15 @@ def test_forward_method(tmp_path): logits, loss = llm(inputs, target_ids=inputs) assert logits.shape == torch.Size([6, 128, 50304]) assert isinstance(loss.item(), float) + + +def test_precision_selection(tmp_path): + with patch("torch.backends.mps.is_available", return_value=USE_MPS): + llm = LLM.load( + model="EleutherAI/pythia-14m", + init="pretrained" + ) + + llm.distribute(precision="16-true") + assert llm.model._forward_module.lm_head.weight.dtype == torch.float16, \ + f"Expected float16, but got {llm.model._forward_module.lm_head.weight.dtype}"