Skip to content

Commit

Permalink
Support tensorflow 2.16.1 and separate CI run for tensorflow estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
nzw0301 committed Mar 16, 2024
1 parent 5861b23 commit 549094f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
18 changes: 11 additions & 7 deletions .github/workflows/tensorflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v3
Expand All @@ -31,14 +31,18 @@ jobs:
pip install git+https://github.com/optuna/optuna-integration.git
python -c 'import optuna_integration'
# NOTE(nabenabe0928): Got "AttributeError: module 'tensorflow' has no attribute 'estimator'".
# TODO(nabenabe0928): Remove this version constraint.
pip install "tensorflow<2.16.1"
pip install -r tensorflow/requirements.txt
- name: Run examples
- name: Run example of TensorFlow eager
run: |
python tensorflow/tensorflow_eager_simple.py
python tensorflow/tensorflow_estimator_simple.py
python tensorflow/tensorflow_estimator_integration.py
env:
OMP_NUM_THREADS: 1
- name: Run examples of TensorFlow estimator
run: |
if [ "${{ matrix.python-version }}" != "3.12" ] ; then
pip install "tensorflow<2.16.0"
python tensorflow/tensorflow_estimator_simple.py
python tensorflow/tensorflow_estimator_integration.py
fi
env:
OMP_NUM_THREADS: 1
4 changes: 4 additions & 0 deletions tensorflow/tensorflow_estimator_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

import optuna
from optuna.trial import TrialState
from packaging import version
import tensorflow_datasets as tfds

import tensorflow as tf


if version.parse(tf.__version__) >= version.parse("2.16.0"):
raise RuntimeError("tensorflow<2.16.0 is required for this example.")

# TODO(crcrpar): Remove the below three lines once everything is ok.
# Register a global custom opener to avoid HTTP Error 403: Forbidden when downloading MNIST.
opener = urllib.request.build_opener()
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/tensorflow_estimator_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@

import optuna
import tensorflow_datasets as tfds
from packaging import version

import tensorflow as tf


if version.parse(tf.__version__) >= version.parse("2.16.0"):
raise RuntimeError("tensorflow<2.16.0 is required for this example.")

# TODO(crcrpar): Remove the below three lines once everything is ok.
# Register a global custom opener to avoid HTTP Error 403: Forbidden when downloading MNIST.
opener = urllib.request.build_opener()
Expand Down

0 comments on commit 549094f

Please sign in to comment.