Skip to content

Commit

Permalink
JAX config updates and expose penalty_param (#894)
Browse files Browse the repository at this point in the history
* jax config updates and expose penalty_param

bump jax version

* make linter and formatter happy
  • Loading branch information
cvsik authored Sep 3, 2024
1 parent e33a29e commit b6191ab
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 12 deletions.
4 changes: 2 additions & 2 deletions dev_tools/requirements/envs/dev.env.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ iniconfig==2.0.0
# via pytest
isort==5.13.2
# via pylint
jax==0.4.23
jax==0.4.31
# via -r deps/resource_estimates_runtime.txt
jaxlib==0.4.23
jaxlib==0.4.31
# via -r deps/resource_estimates_runtime.txt
jsonschema==4.21.0
# via nbformat
Expand Down
4 changes: 2 additions & 2 deletions dev_tools/requirements/envs/pytest-extra.env.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ iniconfig==2.0.0
# via
# -c envs/dev.env.txt
# pytest
jax==0.4.23
jax==0.4.31
# via
# -c envs/dev.env.txt
# -r deps/resource_estimates_runtime.txt
jaxlib==0.4.23
jaxlib==0.4.31
# via
# -c envs/dev.env.txt
# -r deps/resource_estimates_runtime.txt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@
from pyscf.pbc import scf
from scipy.optimize import minimize

from jax.config import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import jax.typing as jnpt

Expand Down
6 changes: 5 additions & 1 deletion src/openfermion/resource_estimates/thc/factorize_thc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def thc_via_cp3(
bfgs_maxiter=5000,
random_start_thc=True,
verify=False,
penalty_param=None,
):
"""
THC-CP3 performs an SVD decomposition of the eri matrix followed by a CP
Expand All @@ -36,6 +37,7 @@ def thc_via_cp3(
random_start_thc - Perform random start for CP3.
If false perform HOSVD start.
verify - check eri properties. Default is False
penalty_param - penalty parameter for L2 regularization. Default is None.
returns:
eri_thc - (N x N x N x N) reconstructed ERIs from THC factorization
Expand Down Expand Up @@ -115,7 +117,9 @@ def thc_via_cp3(
if perform_bfgs_opt:
x = np.hstack((thc_leaf.ravel(), thc_central.ravel()))
# lbfgs_start_time = time.time()
x = lbfgsb_opt_thc_l2reg(eri_full, nthc, initial_guess=x, maxiter=bfgs_maxiter)
x = lbfgsb_opt_thc_l2reg(
eri_full, nthc, initial_guess=x, maxiter=bfgs_maxiter, penalty_param=penalty_param
)
# lbfgs_calc_time = time.time() - lbfgs_start_time
thc_leaf = x[: norb * nthc].reshape(nthc, norb) # leaf tensor nthc x norb
thc_central = x[norb * nthc : norb * nthc + nthc * nthc].reshape(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# coverage:ignore
# pylint: disable=wrong-import-position
import os
from uuid import uuid4
import h5py
import numpy
import numpy.random
import numpy.linalg
from scipy.optimize import minimize

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.config import config
from jax import jit, grad
from .adagrad import adagrad
from .thc_objectives import (
Expand All @@ -22,7 +26,6 @@
# set mkl thread count for numpy einsum/tensordot calls
# leave one CPU un used so we can still access this computer
os.environ["MKL_NUM_THREADS"] = "{}".format(os.cpu_count() - 1)
config.update("jax_enable_x64", True)


class CallBackStore:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# coverage:ignore
# pylint: disable=wrong-import-position
import os
from uuid import uuid4
import scipy.optimize

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.config import config
from jax import jit, grad
import h5py
import numpy
Expand All @@ -15,7 +20,6 @@
# set mkl thread count for numpy einsum/tensordot calls
# leave one CPU un used so we can still access this computer
os.environ["MKL_NUM_THREADS"] = "{}".format(os.cpu_count() - 1)
config.update("jax_enable_x64", True)


def thc_objective_jax(xcur, norb, nthc, eri):
Expand Down

0 comments on commit b6191ab

Please sign in to comment.