Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use uv to install Python packages #26511

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/asan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ jobs:
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
pip install -r build/test-requirements.txt
pip install uv~=0.5.30
uv pip install -r build/test-requirements.txt
- name: Build and install JAX
env:
ASAN_OPTIONS: detect_leaks=0
Expand All @@ -73,8 +74,8 @@ jobs:
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=address \
--clang_path=/usr/bin/clang-18
pip install dist/jaxlib-*.whl
pip install -e .
uv pip install dist/jaxlib-*.whl \
-e .
- name: Run tests
env:
ASAN_OPTIONS: detect_leaks=0
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt

- name: Run tests
Expand Down Expand Up @@ -113,7 +113,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system -r docs/requirements.txt
- name: Test documentation
env:
Expand Down Expand Up @@ -147,7 +147,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system -r docs/requirements.txt
- name: Render documentation
run: |
Expand All @@ -173,7 +173,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system .[minimum-jaxlib] tensorflow -r build/test-requirements.txt

- name: Run tests
Expand Down Expand Up @@ -205,7 +205,7 @@ jobs:
python-version: 3.12
- name: Install JAX
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system .[cuda12]
- name: Build and install example project
run: uv pip install --system ./examples/ffi[test]
Expand Down
36 changes: 16 additions & 20 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,42 +59,38 @@ jobs:
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$PYTHON -m pip install -U -r build/test-requirements.txt
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
$PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
$PYTHON -m pip uninstall -y jax jaxlib libtpu
$PYTHON -m uv pip uninstall -y jax jaxlib libtpu
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
# Build and install jaxlib at head
$PYTHON build/build.py build --wheels=jaxlib \
--bazel_options=--config=rbe_linux_x86_64 \
--local_xla_path="$(pwd)/xla" \
--verbose

$PYTHON -m pip install dist/*.whl

# Install "jax" at head
$PYTHON -m pip install -U -e .

# Install libtpu
$PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Install jaxlib, "jax" at head, and libtpu
$PYTHON -m uv pip install dist/*.whl \
-U -e . \
--pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
$PYTHON -m pip install .[tpu] \
$PYTHON -m uv pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html

elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$PYTHON -m pip install requests
$PYTHON -m uv pip install \
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
install requests

elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$PYTHON -m pip install requests
$PYTHON -m uv pip install \
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
install requests
else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1
Expand Down
5 changes: 2 additions & 3 deletions .github/workflows/cloud-tpu-ci-presubmit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ jobs:
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
$JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Build jaxlib at head with latest XLA
run: |
# Build and install jaxlib at head
Expand All @@ -86,7 +85,7 @@ jobs:
--verbose

# Install libtpu
$JAXCI_PYTHON -m pip install --pre libtpu \
$JAXCI_PYTHON -m uv pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Halt for testing
- name: Wait For Connection
Expand Down
5 changes: 2 additions & 3 deletions .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
uv pip install --system .[ci]
uv pip install --system pytest-xdist -r array-api-tests/requirements.txt
pip install uv~=0.5.30
uv pip install --system .[ci] --system pytest-xdist -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.numpy
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/metal_plugin_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ jobs:
rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
pip install -U pip numpy wheel
pip install absl-py pytest
pip install uv~=0.5.30
uv pip install -U pip numpy wheel absl-py pytest
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
pip install --pre jaxlib \
uv pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
fi;
cd jax
pip install .
pip install jax-metal
uv pip install . jax-metal
- name: Run test
run: |
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ jobs:
echo "Skipping the test run."
exit 1
- name: Install Python dependencies
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ jobs:
echo "Skipping the test run."
exit 1
- name: Install Python dependencies
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tsan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ jobs:

export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH

python3 -m pip install -r requirements/build_requirements.txt
python3 -m pip install uv~=0.5.30
# Make sure to install a compatible Cython version (master branch is best for now)
python3 -m pip install -U git+https://github.com/cython/cython
python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython

CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized

Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/upstream-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install JAX test requirements
run: |
pip install .[ci] -r build/test-requirements.txt
pip install uv~=0.5.30
uv pip install .[ci] -r build/test-requirements.txt
- name: Install numpy & scipy development versions
run: |
pip install \
uv pip install \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
--no-deps \
--pre \
Expand Down
9 changes: 5 additions & 4 deletions .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ jobs:
BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC"
JAXLIB_RELEASE: true
run: |
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
python -m pip install uv~=0.5.30
python -m uv pip install -r build/test-requirements.txt \
--upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
Expand All @@ -57,7 +58,7 @@ jobs:
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
python -m pip install --find-links ${{ github.workspace }}\dist jaxlib
python -m pip install -e ${{ github.workspace }}
python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \
-e ${{ github.workspace }}
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples
8 changes: 4 additions & 4 deletions .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ jobs:
JAXLIB_NIGHTLY: true # Tag the wheels as dev versions
run: |
cd jax
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
python -m pip install uv~=0.5.30
python -m uv pip install -r build/test-requirements.txt --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
Expand All @@ -67,7 +67,7 @@ jobs:
PY_COLORS: 1
run: |
cd jax
python -m pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib
python -m pip install -e ${{ github.workspace }}\jax
python -m uv pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib \
-e ${{ github.workspace }}\jax
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples
12 changes: 9 additions & 3 deletions ci/utilities/install_wheels_locally.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@ fi
echo "Installing the following wheels:"
echo "${WHEELS[@]}"

# Install `uv` if it's not already installed. `uv` is much faster than pip for
# installing Python packages.
if ! command -v uv >/dev/null 2>&1; then
pip install uv~=0.5.30
fi

# On Windows, convert MSYS Linux-like paths to Windows paths.
if [[ $(uname -s) =~ "MSYS_NT" ]]; then
"$JAXCI_PYTHON" -m pip install $(cygpath -w "${WHEELS[@]}")
"$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}")
else
"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}"
"$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}"
fi

if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then
echo "Installing the JAX package in editable mode at the current commit..."
# Install JAX package at the current commit.
"$JAXCI_PYTHON" -m pip install -U -e .
"$JAXCI_PYTHON" -m uv pip install -U -e .
fi
Loading