Skip to content

Commit

Permalink
Add GPU testing to Actions
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelHudgins committed Oct 22, 2024
1 parent 354c15a commit 12a2bbb
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ on:
pull_request:
branches:
- main

permissions:
contents: read # to fetch code
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs:
lint:
runs-on: ubuntu-latest
Expand All @@ -17,3 +21,27 @@ jobs:
with:
python-version: '3.10'
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/[email protected]
test:
runs-on: linux-x86-g2-48-l4-4gpu
container:
image: index.docker.io/library/ubuntu@sha256:0e5e4a57c2499249aafc3b40fcd541e9a456aab7296681a3994d631587203f97 # ratchet:ubuntu:22.04
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
with:
python-version: '3.10'
- name: Setup Released JAX and Torch
run: |
pip install torch
pip install -U "jax[cuda12]"
pip install pytest
- name: Test JAX Triton
run: |
echo "Running JAX Triton GPU Tests"
nvidia-smi
pip install .
# Need newer ml-dtypes because we install newer numpy
pip install --upgrade ml-dtypes
pytest -v --tb=short tests/

0 comments on commit 12a2bbb

Please sign in to comment.