Skip to content

Commit

Permalink
ci
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Aug 6, 2024
1 parent 082394e commit 03dc687
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
7 changes: 5 additions & 2 deletions .github/unittest/install_dependencies_nightly.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
python -m pip install --upgrade pip
python -m pip install flake8 pytest pytest-cov hydra-core tqdm

# Not using nightly torch, MolScore, promptsmiles
python -m pip install torch torchvision MolScore promptsmiles
# Not using nightly torch
python -m pip install torch torchvision
# python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall

# Not testing these dependencies for now
# python -m pip MolScore promptsmiles

cd ../acegen-open
pip install -e .
pip uninstall --yes torchrl
Expand Down
16 changes: 16 additions & 0 deletions tests/test_mamba_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
)
from utils import get_default_devices

try:
from mamba_ssm.models.mixer_seq_simple import MixerModel

mamba_available = True
except ImportError as err:
mamba_available = False


skip_if_mamba_not_available = pytest.mark.skipif(
not mamba_available,
reason="mamba-ssm library is not available, skipping this test",
)


def generate_valid_data_batch(
vocabulary_size: int, batch_size: int, sequence_length: int
Expand All @@ -21,6 +34,7 @@ def generate_valid_data_batch(
return batch


@skip_if_mamba_not_available
@pytest.mark.parametrize("vocabulary_size", [10])
@pytest.mark.parametrize("device", get_default_devices())
def test_mamba_actor(vocabulary_size, device, sequence_length=5, batch_size=10):
Expand Down Expand Up @@ -52,6 +66,7 @@ def test_mamba_actor(vocabulary_size, device, sequence_length=5, batch_size=10):
assert "action" in training_batch.keys()


@skip_if_mamba_not_available
@pytest.mark.parametrize("vocabulary_size", [10])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("critic_value_per_action", [True, False])
Expand Down Expand Up @@ -95,6 +110,7 @@ def test_mamba_critic(
assert "state_value" in training_batch.keys()


@skip_if_mamba_not_available
@pytest.mark.parametrize("vocabulary_size", [10])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("critic_value_per_action", [True, False])
Expand Down

0 comments on commit 03dc687

Please sign in to comment.