Skip to content

Commit

Permalink
feat: package updates with python311
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Nov 14, 2024
1 parent adc5ea9 commit ce99901
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 48 deletions.
8 changes: 4 additions & 4 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32
MAX_UINT32 = 2**31
MIN_UINT32 = 0

SeedType = Union[int, list, np.ndarray]

Expand All @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:

def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
return [new_seed, data]


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand Down
29 changes: 28 additions & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,34 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04
RUN echo "Setting up machine"
RUN apt-get update
RUN apt-get install -y curl tar
RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg
RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg

# Install prerequisites
RUN apt-get update && apt-get install -y \
wget \
build-essential \
zlib1g-dev \
libncurses5-dev \
libssl-dev \
libreadline-dev \
libffi-dev \
curl \
libbz2-dev \
liblzma-dev

# Download and install Python 3.11
RUN cd /tmp \
&& wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \
&& tar -xvzf Python-3.11.0.tgz \
&& cd Python-3.11.0 \
&& ./configure --enable-optimizations \
&& make -j$(nproc) \
&& make altinstall

# Create symlinks for python and pip (use 'pip' instead of 'pip3')
RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \
&& ln -s /usr/local/bin/pip3.11 /usr/bin/pip

RUN apt-get install libtcmalloc-minimal4
RUN apt-get install unzip
RUN apt-get install pigz
Expand Down
86 changes: 43 additions & 43 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ classifiers =
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering :: Artificial Intelligence

[options]
Expand All @@ -34,22 +35,22 @@ setup_requires =
setuptools_scm
# Dependencies of the project:
install_requires =
absl-py==1.4.0
absl-py==2.1.1
# Pin to avoid unpinned install in dependencies that requires Python>=3.9.
networkx==3.1
docker==7.0.0
numpy>=1.23
pandas>=2.0.1
tensorflow==2.12.0
tensorflow-datasets==4.9.2
tensorflow-probability==0.20.0
tensorflow-addons==0.20.0
networkx==3.2.1
docker==7.1.0
numpy>=1.26.4
pandas==2.2.3
tensorflow==2.18.0
tensorflow-datasets==4.9.7
tensorflow-addons==0.23.0
gputil==1.4.0
psutil==5.9.5
clu==0.0.7
matplotlib>=3.7.2
psutil==6.1.0
clu==0.0.12
matplotlib>=3.9.2
tabulate==0.9.0
python_requires = >=3.8
wandb==0.18.7
python_requires = >=3.11


###############################################################################
Expand Down Expand Up @@ -79,78 +80,77 @@ full_dev =

# Dependencies for developing the package
dev =
isort==5.12.0
pylint==2.17.4
pytest==7.3.1
yapf==0.33.0
pre-commit==3.3.1
isort==5.13.2
pylint==3.3.1
pytest==8.3.3
yapf==0.43.0
pre-commit==4.0.1

# Workloads #
criteo1tb =
scikit-learn==1.2.2
scikit-learn==1.5.2

fastmri =
h5py==3.8.0
scikit-image==0.20.0
h5py==3.12.1
scikit-image==0.24.0

ogbg =
jraph==0.0.6.dev0
scikit-learn==1.2.2
scikit-learn==1.5.2

librispeech_conformer =
sentencepiece==0.1.99
tensorflow-text==2.12.1
sentencepiece==0.2.0
tensorflow-text==2.18.0
pydub==0.25.1

wmt =
sentencepiece==0.1.99
tensorflow-text==2.12.1
sacrebleu==1.3.1
sentencepiece==0.2.0
tensorflow-text==2.18.0
sacrebleu==2.4.3

# Frameworks #

# JAX Core
jax_core_deps =
flax==0.6.10
optax==0.1.5
flax==0.10.1
optax==0.2.4
# Fix chex (optax dependency) version.
# Not fixing it can raise dependency issues with our
# jax version.
# Todo(kasimbeg): verify if this is necessary after we
# upgrade jax.
chex==0.1.7
ml_dtypes==0.2.0
protobuf==4.25.3
chex==0.1.87
ml_dtypes==0.4.1
protobuf==4.25.5


# JAX CPU
jax_cpu =
jax==0.4.10
jaxlib==0.4.10
jax==0.4.35
jaxlib==0.4.35
%(jax_core_deps)s

# JAX GPU
# Note this installs both jax and jaxlib.
jax_gpu =
jax==0.4.10
jaxlib==0.4.10+cuda12.cudnn88
jax==0.4.35
jaxlib==0.4.35
jax-cuda12-plugin[with_cuda]==0.4.35
jax-cuda12-pjrt==0.4.35
%(jax_core_deps)s

# PyTorch CPU
pytorch_cpu =
torch==2.1.0
torchvision==0.16.0
torch==2.5.0
torchvision==0.20.0

# PyTorch GPU
# Note: omit the cuda suffix and installing from the appropriate
# wheel will result in using locally installed CUDA.
pytorch_gpu =
torch==2.1.0
torchvision==0.16.0
torch==2.5.0
torchvision==0.20.0

# wandb
wandb =
wandb==0.16.5

###############################################################################
# Linting Configurations #
Expand Down

0 comments on commit ce99901

Please sign in to comment.