diff --git a/.github/unittest/install_dependencies_nightly.sh b/.github/unittest/install_dependencies_nightly.sh index a00b1d4..b337e1d 100644 --- a/.github/unittest/install_dependencies_nightly.sh +++ b/.github/unittest/install_dependencies_nightly.sh @@ -1,10 +1,13 @@ python -m pip install --upgrade pip python -m pip install flake8 pytest pytest-cov hydra-core tqdm -# Not using nightly torch, MolScore, promptsmiles -python -m pip install torch torchvision MolScore promptsmiles +# Not using nightly torch +python -m pip install torch torchvision # python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +# Not testing these dependencies for now +# python -m pip install MolScore promptsmiles + cd ../acegen-open pip install -e . pip uninstall --yes torchrl diff --git a/tests/test_mamba_model.py b/tests/test_mamba_model.py index 807937d..98598ee 100644 --- a/tests/test_mamba_model.py +++ b/tests/test_mamba_model.py @@ -8,6 +8,19 @@ ) from utils import get_default_devices +try: + from mamba_ssm.models.mixer_seq_simple import MixerModel + + mamba_available = True +except ImportError: + mamba_available = False + + +skip_if_mamba_not_available = pytest.mark.skipif( + not mamba_available, + reason="mamba-ssm library is not available, skipping this test", +) + def generate_valid_data_batch( vocabulary_size: int, batch_size: int, sequence_length: int @@ -21,6 +34,7 @@ def generate_valid_data_batch( return batch +@skip_if_mamba_not_available @pytest.mark.parametrize("vocabulary_size", [10]) @pytest.mark.parametrize("device", get_default_devices()) def test_mamba_actor(vocabulary_size, device, sequence_length=5, batch_size=10): @@ -52,6 +66,7 @@ def test_mamba_actor(vocabulary_size, device, sequence_length=5, batch_size=10): assert "action" in training_batch.keys() +@skip_if_mamba_not_available @pytest.mark.parametrize("vocabulary_size", [10]) @pytest.mark.parametrize("device", get_default_devices()) 
@pytest.mark.parametrize("critic_value_per_action", [True, False]) @@ -95,6 +110,7 @@ def test_mamba_critic( assert "state_value" in training_batch.keys() +@skip_if_mamba_not_available @pytest.mark.parametrize("vocabulary_size", [10]) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("critic_value_per_action", [True, False])