Skip to content

Commit

Permalink
Add forced cuda version to workshop install (#78)
Browse files Browse the repository at this point in the history
* Add forced cuda version to `workshop install`

* Add changelog for #78
  • Loading branch information
linusyh authored Feb 23, 2024
1 parent 2efa6c7 commit 14e5181
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

* Adds `InverseSquareRoot` LR scheduler [#71](https://github.com/a-r-j/ProteinWorkshop/pull/71)

### Command
* Adds `--force-cuda-version` to `workshop install` [#78](https://github.com/a-r-j/ProteinWorkshop/pull/78)


### 0.2.5 (28/12/2023)

Expand Down
11 changes: 10 additions & 1 deletion proteinworkshop/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def main():
default=False,
required=False,
)

install_parser.add_argument(
"--force-cuda-version",
type=int,
help="set cuda version manually, ignore automatic detection. e.g. 121 for CUDA 12.1",
default=None,
required=False,
)

install_parser.add_argument(
"dependency", choices=["pyg"], help="dependency help"
)
Expand Down Expand Up @@ -112,7 +121,7 @@ def main():
# lazy import
from .install_pyg import _install_pyg

_install_pyg(args.force_reinstall)
_install_pyg(args.force_reinstall, args.force_cuda_version)

elif args.command == "download":
if args.dataset == "pdb":
Expand Down
7 changes: 6 additions & 1 deletion proteinworkshop/scripts/install_pyg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from loguru import logger


def _install_pyg(force_reinstall: bool = False):
def _install_pyg(force_reinstall: bool = False,
force_cuda_version: int = None):
torch_version = torch.__version__
cuda_version = (
torch.version.cuda.replace(".", "")
Expand All @@ -13,6 +14,10 @@ def _install_pyg(force_reinstall: bool = False):
)
logger.info(f"Detected PyTorch version: {torch_version}")
logger.info(f"Detected CUDA version: {cuda_version}")
if force_cuda_version is not None:
logger.info(f"Forcing CUDA version to {force_cuda_version}")
cuda_version = force_cuda_version

logger.info(
f"Installing PyTorch Geometric for PyTorch {torch_version} and CUDA {cuda_version}"
)
Expand Down

0 comments on commit 14e5181

Please sign in to comment.