diff --git a/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt b/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
index 8828ccea..2b9fd336 100644
--- a/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
+++ b/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
@@ -1 +1,2 @@
-transformers>=4.44.2
\ No newline at end of file
+transformers>=4.44.2
+litgpt>=0.5.0
\ No newline at end of file
diff --git a/ch05/07_gpt_to_llama/tests/tests.py b/ch05/07_gpt_to_llama/tests/tests.py
index 08f17550..7b0d4fa0 100644
--- a/ch05/07_gpt_to_llama/tests/tests.py
+++ b/ch05/07_gpt_to_llama/tests/tests.py
@@ -235,7 +235,6 @@ def test_rope_llama3(notebook):
     torch.testing.assert_close(queries_rot, litgpt_queries_rot)
 
 
-
 def test_rope_llama3_12(notebook):
 
     nb1 = notebook["converting-gpt-to-llama2"]
@@ -312,7 +311,10 @@ class RoPEConfig:
     }
 
     litgpt_cos, litgpt_sin = litgpt_build_rope_cache(
-        context_len, n_elem=head_dim, base=rope_theta, extra_config=litgpt_rope_config
+        context_len,
+        n_elem=head_dim,
+        base=rope_theta,
+        extra_config=litgpt_rope_config
     )
     litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
     litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)