diff --git a/.github/workflows/shampoo_tests.yml b/.github/workflows/shampoo_tests.yml new file mode 100644 index 000000000..c50096bfe --- /dev/null +++ b/.github/workflows/shampoo_tests.yml @@ -0,0 +1,96 @@ +name: Containerized Regression Tests for Shampoo + +on: + pull_request: + branches: + - 'dev' + +jobs: + build_and_push_jax_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker images + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=jax + IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + fastmri_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w fastmri -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + imagenet_resnet_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + imagenet_vit_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + ogbg_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w ogbg -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + criteo_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + librispeech_conformer_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + librispeech_deepspeech_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + wmt_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s reference_algorithms/paper_baselines/shampoo/jax/submission.py -w wmt -t reference_algorithms/paper_baselines/shampoo/tuning_search_space.json -e tests/regression_tests/shampoo -m 10 -c False -o True -r false + \ No newline at end of file diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..1336f3c80 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -1318,7 +1318,7 @@ def distributed_shampoo( ### inverse_failure_threshold=0.1, moving_average_for_momentum=True, - skip_preconditioning_dim_size_gt=0, + skip_preconditioning_dim_size_gt=4096, clip_by_scaled_gradient_norm=None, precision=lax.Precision.HIGHEST, tensordot_precision=None,