diff --git a/src/abstractmps.jl b/src/abstractmps.jl index 4670be5..6d78229 100644 --- a/src/abstractmps.jl +++ b/src/abstractmps.jl @@ -1672,30 +1672,43 @@ provided as keyword arguments. Keyword arguments: * `site_range`=1:N - only truncate the MPS bonds between these sites +* `truncation_error=false` - If `true`, will return a vector containing the trucation error calculated at each bond. """ function truncate!(M::AbstractMPS; alg="frobenius", kwargs...) return truncate!(Algorithm(alg), M; kwargs...) end function truncate!( - ::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), kwargs... + ::Algorithm"frobenius", + M::AbstractMPS; + site_range=1:length(M), + truncation_error = false, + kwargs..., ) N = length(M) - + nbonds = N - 1 + truncation_errors = zeros(real(scalartype(M)), nbonds) # Left-orthogonalize all tensors to make # truncations controlled orthogonalize!(M, last(site_range)) # Perform truncations in a right-to-left sweep - for j in reverse((first(site_range) + 1):last(site_range)) + js = reverse((first(site_range) + 1):last(site_range)) + for i in eachindex(js) + j = js[i] rinds = uniqueinds(M[j], M[j - 1]) ltags = tags(commonind(M[j], M[j - 1])) - U, S, V = svd(M[j], rinds; lefttags=ltags, kwargs...) + U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...) + truncation_errors[i] = spec.truncerr M[j] = U M[j - 1] *= (S * V) setrightlim!(M, j) end - return M + if truncation_error + return truncation_errors + else + return M + end end function truncate(ψ0::AbstractMPS; kwargs...) diff --git a/test/base/test_mps.jl b/test/base/test_mps.jl index 65eed0e..dfad897 100644 --- a/test/base/test_mps.jl +++ b/test/base/test_mps.jl @@ -755,6 +755,17 @@ end truncate!(M; site_range=3:7, maxdim=2) @test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2] end + + @testset "truncate! with truncation_error" begin + nsites = 10 + nbonds = nsites - 1 + mps_ = basicRandomMPS(nsites; dim=10) + truncation_errors = truncate!(mps_, maxdim=3, cutoff=1E-3, truncation_error=true) + @test length(truncation_errors) == nbonds + @test all(truncation_errors .>= 0.0) + end + + end @testset "Other MPS methods" begin