Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/PEPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ include("algorithms/truncation/truncationschemes.jl")
include("algorithms/truncation/fullenv_truncation.jl")
include("algorithms/truncation/bond_truncation.jl")

include("algorithms/time_evolution/trotter_gate.jl")
include("algorithms/time_evolution/apply_gate.jl")
include("algorithms/time_evolution/apply_mpo.jl")
include("algorithms/time_evolution/trotter_gate.jl")
include("algorithms/time_evolution/time_evolve.jl")
include("algorithms/time_evolution/simpleupdate.jl")
include("algorithms/time_evolution/simpleupdate3site.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/bp/beliefpropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function leading_boundary(env₀::BPEnv, network::InfiniteSquareNetwork, alg::Be
end
function leading_boundary(env₀::BPEnv, state::InfiniteState, alg::BeliefPropagation)
if alg.bipartite
@assert _state_bipartite_check(state)
_is_bipartite(state) || error("Input state is not bipartite")
end
return leading_boundary(env₀, InfiniteSquareNetwork(state), alg)
end
Expand Down
14 changes: 2 additions & 12 deletions src/algorithms/bp/gaugefix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@ Algorithm for gauging PEPS with belief propagation fixed point messages.
# TODO: add options
end

function _bpenv_bipartite_check(env::BPEnv)
for (r, c) in Iterators.product(1:2, 1:2)
r′, c′ = _next(r, 2), _next(c, 2)
if !all(env[:, r, c] .== env[:, r′, c′])
return false
end
end
return true
end

"""
gauge_fix(psi::Union{InfinitePEPS, InfinitePEPO}, alg::BPGauge, env::BPEnv)

Expand All @@ -25,7 +15,7 @@ an [`InfinitePEPO`](@ref) interpreted as purified state with two physical legs)
using fixed point environment `env` of belief propagation.
"""
function gauge_fix(psi::InfinitePEPS, alg::BPGauge, env::BPEnv)
bipartite = _state_bipartite_check(psi) && _bpenv_bipartite_check(env)
bipartite = _is_bipartite(psi) && _is_bipartite(env)
psi′ = copy(psi)
XXinv = map(eachcoordinate(psi, 1:2)) do I
_, X, Xinv = _bp_gauge_fix!(CartesianIndex(I), psi′, env)
Expand All @@ -52,7 +42,7 @@ function gauge_fix(psi::InfinitePEPO, alg::BPGauge, env::BPEnv)
Fs = map(Base.Fix2(getindex, 2), psi_Fs)
psi′, XXinv = gauge_fix(InfinitePEPS(psi′), alg, env)
# convert back to iPEPO
psi′ = map(zip(psi′.A, Fs)) do (t, F)
psi′ = map(psi′.A, Fs) do t, F
return F' * t
end
psi′ = reshape(psi′, (Nr, Nc, 1))
Expand Down
11 changes: 4 additions & 7 deletions src/algorithms/time_evolution/apply_gate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const NNGate{T, S} = AbstractTensorMap{T, S, 2, 2}

"""
Apply 1-site `gate` on the PEPS or PEPO tensor `a`.
"""
Expand Down Expand Up @@ -125,20 +127,15 @@ Apply 2-site `gate` on the reduced matrices `a`, `b`
"""
function _apply_gate(
a::AbstractTensorMap, b::AbstractTensorMap,
gate::AbstractTensorMap{T, S, 2, 2}, trunc::TruncationStrategy
) where {T <: Number, S <: ElementarySpace}
gate::NNGate, trunc::TruncationStrategy
)
V = space(b, 1)
need_flip = isdual(V)
if isdual(space(a, 2))
@tensor a2b2[-1 -2; -3 -4] := gate[1 2; -2 -3] * a[-1 1 3] * b[3 2 -4]
else
@tensor a2b2[-1 -2; -3 -4] := gate[-2 -3; 1 2] * a[-1 1 3] * b[3 2 -4]
end
trunc = if trunc isa FixedSpaceTruncation
need_flip ? truncspace(flip(V)) : truncspace(V)
else
trunc
end
a, s, b, ϵ = svd_trunc!(a2b2; trunc)
a, b = absorb_s(a, s, b)
if need_flip
Expand Down
10 changes: 2 additions & 8 deletions src/algorithms/time_evolution/apply_mpo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,8 @@ function _get_allprojs(
N = length(Ms)
Rs, Ls = _get_allRLs(Ms)
@assert length(truncs) == N - 1
projs_errs = map(1:(N - 1)) do i
trunc = if isa(truncs[i], FixedSpaceTruncation)
tspace = space(Ms[i + 1], 1)
isdual(tspace) ? truncspace(flip(tspace)) : truncspace(tspace)
else
truncs[i]
end
return _proj_from_RL(Rs[i], Ls[i]; trunc)
projs_errs = map(Rs, Ls, truncs) do R, L, trunc
return _proj_from_RL(R, L; trunc)
end
Pas = map(Base.Fix2(getindex, 1), projs_errs)
wts = map(Base.Fix2(getindex, 2), projs_errs)
Expand Down
10 changes: 6 additions & 4 deletions src/algorithms/time_evolution/gaugefix_su.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,24 @@ end
Fix the gauge of `psi` using trivial simple update.
"""
function gauge_fix(psi::InfiniteState, alg::SUGauge)
time0 = time()
gates = _trivial_gates(scalartype(psi), physicalspace(psi))
su_alg = SimpleUpdate(; trunc = FixedSpaceTruncation(), bipartite = _state_bipartite_check(psi))
trunc = _get_fixedspacetrunc(psi)
su_alg = SimpleUpdate(; trunc, bipartite = _is_bipartite(psi))
wts0 = SUWeight(psi)
# use default constructor to avoid calculation of exp(-H * 0)
evolver = TimeEvolver(su_alg, 0.0, alg.maxiter, gates, SUState(0, 0.0, psi, wts0))
for (i, (psi′, wts, info)) in enumerate(evolver)
ϵ = compare_weights(wts, wts0)
if i >= alg.miniter && ϵ < alg.tol
@info "Trivial SU conv $i: |Δλ| = $ϵ."
@info "Trivial SU conv $i: |Δλ| = $ϵ, time = $(time() - time0) s"
return psi′, wts, ϵ
end
if i == alg.maxiter
@warn "Trivial SU cancel $i: |Δλ| = $ϵ."
@warn "Trivial SU cancel $i: |Δλ| = $ϵ, time = $(time() - time0) s"
return psi′, wts, ϵ
end
wts0 = deepcopy(wts)
wts0 = wts
end
return
end
101 changes: 69 additions & 32 deletions src/algorithms/time_evolution/simpleupdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ Algorithm struct for simple update (SU) of InfinitePEPS or InfinitePEPO.

$(TYPEDFIELDS)
"""
@kwdef struct SimpleUpdate <: TimeEvolution
@kwdef struct SimpleUpdate{T <: TruncationStrategy} <: TimeEvolution
"Truncation strategy for bonds updated by Trotter gates"
trunc::TruncationStrategy
trunc::T
"When true (or false), the Trotter gate is `exp(-H dt)` (or `exp(-iH dt)`)"
imaginary_time::Bool = true
"When true, force decomposition of nearest neighbor gates to MPOs."
force_mpo::Bool = false
"When true, assume bipartite unit cell structure"
bipartite::Bool = false
"(Only applicable to InfinitePEPO)
"(Only applicable to InfinitePEPO)
When true, the PEPO is regarded as a purified PEPS, and updated as
`|ρ(t + dt)⟩ = exp(-H dt/2) |ρ(t)⟩`.
When false, the PEPO is updated as
When false, the PEPO is updated as
`ρ(t + dt) = exp(-H dt/2) ρ(t) exp(-H dt/2)`."
purified::Bool = true
end
Expand Down Expand Up @@ -62,6 +62,12 @@ function TimeEvolver(
# create Trotter gates
gate = trotterize(H, dt′; symmetrize_gates, force_mpo = alg.force_mpo)
state = SUState(0, t0, psi0, env0)
# convert FixedSpaceTruncation to site-dependent `truncspace`s
if alg.trunc isa FixedSpaceTruncation
trunc = _get_fixedspacetrunc(psi0)
@reset alg.trunc = trunc
end
# TODO: bipartite check for alg.trunc after equality is defined for all kinds of truncation strategies
# TODO: check gates for bipartite case
return TimeEvolver(alg, dt, nstep, gate, state)
end
Expand Down Expand Up @@ -89,8 +95,8 @@ function _su_iter!(
sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate
)
Nr, Nc = size(state)
truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc))
@assert length(sites) == 2 && length(truncs) == 1
@assert length(sites) == 2
trunc = only(_get_cluster_trunc(alg.trunc, sites, (Nr, Nc)))
Ms, open_vaxs, = _get_cluster(state, sites, env; permute = false)
normalize!.(Ms, Inf)
# rotate
Expand All @@ -101,7 +107,7 @@ function _su_iter!(
gate_axs = alg.purified ? (1:1) : (1:2)
for gate_ax in gate_axs
X, a, b, Y = _qr_bond(A, B; gate_ax, positive = true)
a, s, b, ϵ′ = _apply_gate(a, b, gate, truncs[1])
a, s, b, ϵ′ = _apply_gate(a, b, gate, trunc)
ϵ = max(ϵ, ϵ′)
A, B = _qr_bond_undo(X, a, b, Y)
end
Expand Down Expand Up @@ -148,14 +154,14 @@ function su_iter(
(!alg.bipartite) && continue
if d == 1
rp1, cp1 = _next(r, Nr), _next(c, Nc)
state2[rp1, cp1] = deepcopy(state2[r, c])
state2[rp1, c] = deepcopy(state2[r, cp1])
env2[1, rp1, cp1] = deepcopy(env2[1, r, c])
state2[rp1, cp1] = copy(state2[r, c])
state2[rp1, c] = copy(state2[r, cp1])
env2[1, rp1, cp1] = copy(env2[1, r, c])
else
rm1, cm1 = _prev(r, Nr), _prev(c, Nc)
state2[rm1, cm1] = deepcopy(state2[r, c])
state2[r, cm1] = deepcopy(state2[rm1, c])
env2[2, rm1, cm1] = deepcopy(env2[2, r, c])
state2[rm1, cm1] = copy(state2[r, c])
state2[r, cm1] = copy(state2[rm1, c])
env2[2, rm1, cm1] = copy(env2[2, r, c])
end
else
# N-site MPO gate (N ≥ 2)
Expand Down Expand Up @@ -202,10 +208,8 @@ function MPSKit.timestep(
end

"""
time_evolve(
it::TimeEvolver{<:SimpleUpdate};
tol::Float64 = 0.0, check_interval::Int = 500
) -> (psi, env, info)
time_evolve(it; check_interval = 500) -> (psi, env, info)
time_evolve(it, H; tol = 1.0e-8, check_interval = 500) -> (psi, env, info)

Perform time evolution to the end of `TimeEvolver` iterator `it`,
or until convergence of `SUWeight` set by a positive `tol`.
Expand All @@ -215,15 +219,41 @@ or until convergence of `SUWeight` set by a positive `tol`.
- `check_interval` sets the number of iterations between outputs of information.
"""
function MPSKit.time_evolve(
it::TimeEvolver{<:SimpleUpdate};
tol::Float64 = 0.0, check_interval::Int = 500
it::TimeEvolver{<:SimpleUpdate}; check_interval::Int = 500
)
time_start = time()
check_convergence = (tol > 0)
@info "--- Time evolution (simple update), dt = $(it.dt) ---"
if check_convergence
@assert (it.state.psi isa InfinitePEPS) && it.alg.imaginary_time "Only imaginary time evolution of InfinitePEPS allows convergence checking."
env0, time0 = it.state.env, time()
for (psi, env, info) in it
iter = it.state.iter
diff = compare_weights(env0, env)
stop = (iter == it.nstep)
showinfo = (check_interval > 0) &&
((iter % check_interval == 0) || (iter == 1) || stop)
Comment thread
lkdvos marked this conversation as resolved.
time1 = time()
if showinfo
@info "Space of x-weight at [1, 1] = $(space(env[1, 1, 1], 1))"
@info @sprintf("SU iter %-7d: |Δλ| = %.3e. Time = %.3f s/it", iter, diff, time1 - time0)
end
if stop
time_end = time()
@info @sprintf("Time evolution finished in %.2f s", time_end - time_start)
return psi, env, info
else
env0 = env
end
time0 = time()
end
return
end

function MPSKit.time_evolve(
it::TimeEvolver{<:SimpleUpdate, G, S}, H::LocalOperator;
tol::Float64 = 1.0e-8, check_interval::Int = 500
) where {G, S <: SUState{<:InfinitePEPS}}
time_start = time()
@info "--- Time evolution (simple update), dt = $(it.dt) ---"
@assert it.alg.imaginary_time "Only imaginary time evolution of InfinitePEPS allows convergence checking."
env0, time0 = it.state.env, time()
for (psi, env, info) in it
iter = it.state.iter
Expand All @@ -233,16 +263,20 @@ function MPSKit.time_evolve(
((iter % check_interval == 0) || (iter == 1) || stop)
time1 = time()
if showinfo
# TODO: convert to BPEnv instead
ctmenv = CTMRGEnv(env)
energy = real(expectation_value(psi, H, ctmenv)) / prod(size(psi))
@info "Space of x-weight at [1, 1] = $(space(env[1, 1, 1], 1))"
@info @sprintf("SU iter %-7d: |Δλ| = %.3e. Time = %.3f s/it", iter, diff, time1 - time0)
@info @sprintf(
"SU iter %-7d: E ≈ %.5f, |Δλ| = %.3e. Time = %.3f s/it",
iter, energy, diff, time1 - time0
)
end
if check_convergence
if (iter == it.nstep) && (diff >= tol)
@warn "SU: bond weights have not converged."
end
if diff < tol
@info "SU: bond weights have converged."
end
if (iter == it.nstep) && (diff >= tol)
@warn "SU: bond weights have not converged."
end
if diff < tol
@info "SU: bond weights have converged."
end
if stop
time_end = time()
Expand All @@ -269,7 +303,6 @@ algorithm `alg`, time step `dt` for `nstep` number of steps.

- Set `symmetrize_gates = true` for second-order Trotter decomposition.
- Set `tol > 0` to enable convergence check (for imaginary time evolution of iPEPS only).
For other usages it should not be changed.
- Use `t0` to specify the initial time of the evolution.
- `check_interval` sets the interval to output information. Output during the evolution can be turned off by setting `check_interval <= 0`.
- `info` is a NamedTuple containing information of the evolution,
Expand All @@ -281,5 +314,9 @@ function MPSKit.time_evolve(
tol::Float64 = 0.0, t0::Number = 0.0, check_interval::Int = 500
)
it = TimeEvolver(psi0, H, dt, nstep, alg, env0; t0, symmetrize_gates)
return time_evolve(it; tol, check_interval)
return if tol == 0
time_evolve(it; check_interval)
else
time_evolve(it, H; tol, check_interval)
end
end
28 changes: 9 additions & 19 deletions src/algorithms/time_evolution/simpleupdate3site.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ function _get_cluster(
p = invperm((p1..., p2...))
return (p[1:Np], p[(Np + 1):end])
end
Ms = map(zip(sites, open_vaxs, perms)) do (site, vaxs, perm)
Ms = map(sites, open_vaxs, perms) do site, vaxs, perm
s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc))
M = if env === nothing
state[s]
Expand Down Expand Up @@ -190,7 +190,7 @@ function _su_iter!(
# restore virtual arrows in `Ms`
_flip_virtuals!(Ms, flips)
# update env weights
bond_revs = map(zip(sites, Iterators.drop(sites, 1))) do (site1, site2)
bond_revs = map(sites, Iterators.drop(sites, 1)) do site1, site2
_nn_bondrev(site1, site2, (Nr, Nc))
end
for (wt, (bond, rev), flip) in zip(wts, bond_revs, flips)
Expand All @@ -217,24 +217,14 @@ updated by the Trotter evolution MPO.
"""
function _get_cluster_trunc(
trunc::TruncationStrategy, sites::Vector{CartesianIndex{2}},
(Nrow, Ncol)::NTuple{2, Int}
unitcell::NTuple{2, Int}
)
return map(zip(sites, Iterators.drop(sites, 1))) do (site1, site2)
diff = site2 - site1
if diff == CartesianIndex(0, 1)
r, c = mod1(site1[1], Nrow), mod1(site1[2], Ncol)
return truncation_strategy(trunc, 1, r, c)
elseif diff == CartesianIndex(0, -1)
r, c = mod1(site2[1], Nrow), mod1(site2[2], Ncol)
return truncation_strategy(trunc, 1, r, c)
elseif diff == CartesianIndex(1, 0)
r, c = mod1(site2[1], Nrow), mod1(site2[2], Ncol)
return truncation_strategy(trunc, 2, r, c)
elseif diff == CartesianIndex(-1, 0)
r, c = mod1(site1[1], Nrow), mod1(site1[2], Ncol)
return truncation_strategy(trunc, 2, r, c)
else
error("The path `sites` contains a long-range bond.")
return map(sites, Iterators.drop(sites, 1)) do site1, site2
(d, r, c), rev = _nn_bondrev(site1, site2, unitcell)
t = truncation_strategy(trunc, d, r, c)
if rev && isa(t, TruncationSpace)
t = truncspace(flip(t.space)')
end
return t
Comment thread
lkdvos marked this conversation as resolved.
end
end
Loading
Loading