Skip to content

Commit

Permalink
Use StableRNGs.jl to make tests more consistent (#80)
Browse files Browse the repository at this point in the history
* Use StableRNGs to make tests more consistent

* Bump to v0.4.2
  • Loading branch information
mtfishman authored May 17, 2024
1 parent 703ac67 commit a435b7b
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorTDVP"
uuid = "25707e16-a4db-4a07-99d9-4d67b7af0342"
authors = ["Matthew Fishman <[email protected]> and contributors"]
version = "0.4.1"
version = "0.4.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 3 additions & 1 deletion test/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using ITensors: ITensors, dag, delta, denseblocks
using ITensors: MPO, OpSum, apply, contract, inner, randomMPS, siteinds, truncate!
using ITensorTDVP: ITensorTDVP
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
@testset "Contract MPO (eltype=$elt, conserve_qns=$conserve_qns)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
Expand All @@ -10,7 +11,8 @@ using Test: @test, @test_throws, @testset

N = 20
s = siteinds("S=1/2", N; conserve_qns)
psi = randomMPS(elt, s, j -> isodd(j) ? "" : ""; linkdims=8)
rng = StableRNG(1234)
psi = randomMPS(rng, elt, s, j -> isodd(j) ? "" : ""; linkdims=8)
os = OpSum()
for j in 1:(N - 1)
os += 0.5, "S+", j, "S-", j + 1
Expand Down
4 changes: 3 additions & 1 deletion test/test_dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@eval module $(gensym())
using ITensors: ITensors, MPO, OpSum, inner, randomMPS, siteinds
using ITensorTDVP: ITensorTDVP
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
@testset "DMRG (eltype=$elt, nsite=$nsite, conserve_qns=$conserve_qns)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
Expand All @@ -18,7 +19,8 @@ using Test: @test, @test_throws, @testset
os += "Sz", j, "Sz", j + 1
end
H = MPO(elt, os, s)
psi = randomMPS(elt, s, j -> isodd(j) ? "" : ""; linkdims=20)
rng = StableRNG(1234)
psi = randomMPS(rng, elt, s, j -> isodd(j) ? "" : ""; linkdims=20)
nsweeps = 10
maxdim = [10, 20, 40, 100]
@test_throws ErrorException ITensorTDVP.dmrg(H, psi; maxdim, cutoff, nsite)
Expand Down
6 changes: 4 additions & 2 deletions test/test_dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using ITensors: ITensors, MPO, MPS, OpSum, ProjMPO, inner, siteinds
using ITensorTDVP: dmrg_x
using Random: Random
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
@testset "DMRG-X (eltype=$elt, conserve_qns=$conserve_qns)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
Expand All @@ -27,9 +28,10 @@ using Test: @test, @test_throws, @testset
Random.seed!(12)
W = 12
# Random fields h ∈ [-W, W]
h = W * (2 * rand(real(elt), n) .- 1)
rng = StableRNG(1234)
h = W * (2 * rand(rng, real(elt), n) .- 1)
H = MPO(elt, heisenberg(n; h), s)
initstate = rand(["", ""], n)
initstate = rand(rng, ["", ""], n)
ψ = MPS(elt, s, initstate)
@test_throws ErrorException dmrg_x(H, ψ; nsite=2, maxdim=20, cutoff=1e-10)
dmrg_x_kwargs = (; nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0)
Expand Down
21 changes: 12 additions & 9 deletions test/test_linsolve.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
@eval module $(gensym())
using ITensors: ITensors, MPO, OpSum, apply, randomMPS, siteinds
using ITensors: scalartype
using ITensors.ITensorMPS: MPO, OpSum, apply, randomMPS, siteinds
using ITensorTDVP: ITensorTDVP, dmrg
using KrylovKit: linsolve
using LinearAlgebra: norm
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
using Random: Random
@testset "linsolve (eltype=$elt, conserve_qns=$conserve_qns)" for elt in (
Expand All @@ -18,25 +20,26 @@ using Random: Random
os += 0.5, "S-", j, "S+", j + 1
os += "Sz", j, "Sz", j + 1
end
H = ITensors.convert_leaf_eltype(elt, MPO(os, s))
H = MPO(elt, os, s)
state = [isodd(n) ? "Up" : "Dn" for n in 1:N]
Random.seed!(1234)
x_c = randomMPS(elt, s, state; linkdims=2)
rng = StableRNG(1234)
x_c = randomMPS(rng, elt, s, state; linkdims=2)
e, x_c = dmrg(H, x_c; nsweeps=10, cutoff=1e-6, maxdim=20, outputlevel=0)
@test ITensors.scalartype(x_c) == elt
@test scalartype(x_c) == elt
# Compute `b = H * x_c`
b = apply(H, x_c; cutoff=1e-8)
@test ITensors.scalartype(b) == elt
@test scalartype(b) == elt
# Starting guess
x0 = x_c + elt(0.05) * randomMPS(elt, s, state; linkdims=2)
@test ITensors.scalartype(x0) == elt
rng = StableRNG(1234)
x0 = x_c + elt(0.05) * randomMPS(rng, elt, s, state; linkdims=2)
@test scalartype(x0) == elt
nsweeps = 10
cutoff = 1e-5
maxdim = 20
updater_kwargs = (; tol=1e-4, maxiter=20, krylovdim=30, ishermitian=true)
@test_throws ErrorException linsolve(H, b, x0; cutoff, maxdim, updater_kwargs)
x = linsolve(H, b, x0; nsweeps, cutoff, maxdim, updater_kwargs)
@test ITensors.scalartype(x) == elt
@test scalartype(x) == elt
@test norm(x - x_c) < 1e-2
end
end
15 changes: 10 additions & 5 deletions test/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using ITensorTDVP: ITensorTDVP, tdvp
using KrylovKit: exponentiate
using LinearAlgebra: norm
using Observers: observer
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "Basic TDVP (eltype=$elt)" for elt in elts
Expand All @@ -33,7 +34,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
os += "Sz", j, "Sz", j + 1
end
H = MPO(elt, os, s)
ψ0 = randomMPS(elt, s; linkdims=10)
rng = StableRNG(1234)
ψ0 = randomMPS(rng, elt, s; linkdims=10)
time_step = elt(0.1) * im
# Time evolve forward:
ψ1 = tdvp(H, -time_step, ψ0; cutoff, nsite=1)
Expand Down Expand Up @@ -71,15 +73,16 @@ end
H1 = MPO(elt, os1, s)
H2 = MPO(elt, os2, s)
Hs = [H1, H2]
ψ0 = randomMPS(elt, s; linkdims=10)
rng = StableRNG(1234)
ψ0 = randomMPS(rng, elt, s; linkdims=10)
ψ1 = tdvp(Hs, -elt(0.1) * im, ψ0; cutoff, nsite=1)
@test ITensors.scalartype(ψ1) === complex(elt)
@test norm(ψ1) 1 rtol = eps(real(elt))
## Should lose fidelity:
#@test abs(inner(ψ0,ψ1)) < 0.9
# Average energy should be conserved:
@test real(sum(H -> inner(ψ1', H, ψ1), Hs)) sum(H -> inner(ψ0', H, ψ0), Hs) rtol =
2 * eps(real(elt))
4 * eps(real(elt))
# Time evolve backwards:
ψ2 = tdvp(Hs, elt(0.1) * im, ψ1; cutoff)
@test ITensors.scalartype(ψ2) === complex(elt)
Expand All @@ -98,7 +101,8 @@ end
os += "Sz", j, "Sz", j + 1
end
H = MPO(os, s)
ψ0 = randomMPS(s; linkdims=10)
rng = StableRNG(1234)
ψ0 = randomMPS(rng, s; linkdims=10)
function updater(PH, state0; internal_kwargs, kwargs...)
return exponentiate(PH, internal_kwargs.time_step, state0; kwargs...)
end
Expand Down Expand Up @@ -252,7 +256,8 @@ end
os += "Sz", j, "Sz", j + 1
end
H = MPO(os, s)
state = randomMPS(s; linkdims=2)
rng = StableRNG(1234)
state = randomMPS(rng, s; linkdims=2)
state2 = deepcopy(state)
trange = 0.0:tau:ttotal
for (step, t) in enumerate(trange)
Expand Down
14 changes: 8 additions & 6 deletions test/test_tdvp_time_dependent.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
@eval module $(gensym())
using ITensors: ITensors, Index, QN, contract, randomITensor
using ITensors: ITensors, Index, QN, contract, randomITensor, scalartype
using ITensors.ITensorMPS: MPO, MPS, ProjMPO, ProjMPOSum, randomMPS, position!, siteinds
using ITensorTDVP: ITensorTDVP, TimeDependentSum, tdvp
using LinearAlgebra: norm
using StableRNGs: StableRNG
using Test: @test, @test_skip, @testset
include(joinpath(pkgdir(ITensorTDVP), "examples", "03_models.jl"))
include(joinpath(pkgdir(ITensorTDVP), "examples", "03_updaters.jl"))
Expand All @@ -17,7 +18,8 @@ include(joinpath(pkgdir(ITensorTDVP), "examples", "03_updaters.jl"))
H = MPO(elt, s, "I")
H⃗ = (H, H)
region = 2:3
ψ = randomMPS(elt, s, j -> isodd(j) ? "" : ""; linkdims=2)
rng = StableRNG(1234)
ψ = randomMPS(rng, elt, s, j -> isodd(j) ? "" : ""; linkdims=2)
H⃗ᵣ = ProjMPO.(H⃗)
map(Hᵣ -> position!(Hᵣ, ψ, first(region)), H⃗ᵣ)
∑Hᵣ = ProjMPOSum(collect(H⃗))
Expand Down Expand Up @@ -87,10 +89,10 @@ include(joinpath(pkgdir(ITensorTDVP), "examples", "03_updaters.jl"))
abstol=tol,
)

@test ITensors.scalartype(ψ₀) == complex(elt)
@test ITensors.scalartype(ψₜ_ode) == complex(elt)
@test ITensors.scalartype(ψₜ_krylov) == complex(elt)
@test ITensors.scalartype(ψₜ_full) == complex(elt)
@test scalartype(ψ₀) == complex(elt)
@test scalartype(ψₜ_ode) == complex(elt)
@test scalartype(ψₜ_krylov) == complex(elt)
@test scalartype(ψₜ_full) == complex(elt)
@test norm(ψ₀) 1
@test norm(ψₜ_ode) 1
@test norm(ψₜ_krylov) 1 rtol = (eps(real(elt)))
Expand Down

0 comments on commit a435b7b

Please sign in to comment.