Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
f2f9011
Refactor `NestedAlgorithm` hooks: `initialize_subsolve` + `finalize_s…
mtfishman May 19, 2026
314c294
Strip AIE down to minimal NestedAlgorithm + abstract scaffolding
mtfishman May 19, 2026
e35b1a2
Drop AIE `Problem` / `Algorithm` / `State` / `DefaultState` scaffolding
mtfishman May 19, 2026
b1821ce
Refactor BP into three problem/algorithm/state triples
mtfishman May 19, 2026
fdb3c04
Reorder BP source top-to-bottom
mtfishman May 19, 2026
56b2d82
Move StopWhenConverged + iterate_diff verb into AIE
mtfishman May 19, 2026
4fd8a47
Move `edge` field from `SimpleMessageUpdateAlgorithm` to `MessageUpda…
mtfishman May 19, 2026
81f7b4f
Move BP edge ordering from `BeliefPropagationProblem` to `BeliefPropa…
mtfishman May 19, 2026
6dbf365
Store edges on `BeliefPropagationAlgorithm`, not the sweep algorithm
mtfishman May 19, 2026
0d871c2
Index per-edge BP algorithms by edge; drop `AbstractVector` constraints
mtfishman May 19, 2026
e5a58d0
Rename `beliefpropagationproblem.jl` to `beliefpropagation.jl`
mtfishman May 19, 2026
2e8ee3b
Simplify BP API: single child algorithms + select_* selectors
mtfishman May 20, 2026
c8cafb9
Minor reorg and line-collapse cleanup in `beliefpropagation`
mtfishman May 20, 2026
b0a1e90
Collapse BP message-update AI layer; introduce `MessageUpdateAlgorith…
mtfishman May 20, 2026
3c9c2d8
Add args tuple + AbstractAlgorithm supertype to `select_algorithm`
mtfishman May 20, 2026
5d56f9f
Return mutated cache from `message_update!`; accept `tol` in stopping…
mtfishman May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
version = "0.4.2"
version = "0.4.3"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
304 changes: 106 additions & 198 deletions src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,250 +2,158 @@ module AlgorithmsInterfaceExtensions

import AlgorithmsInterface as AI

# ========================== Patches for AlgorithmsInterface.jl ============================
# ============================ NestedAlgorithm =============================================

abstract type Problem <: AI.Problem end
abstract type Algorithm <: AI.Algorithm end
abstract type State <: AI.State end
abstract type NestedAlgorithm <: AI.Algorithm end

function AI.initialize_state!(
problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs...
)
for (k, v) in pairs(kwargs)
setproperty!(state, k, v)
end
state.iteration = iteration
AI.initialize_state!(
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
# Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it
# returns the `(subproblem, subalgorithm, substate)` tuple that the next
# inner `AI.solve!` call consumes. The default `finalize_substate!` copies
# the substate's iterate back into the parent state; subtypes can override
# when more is required.
function initialize_subsolve(
problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State
)
return state
return throw(MethodError(initialize_subsolve, (problem, algorithm, state)))
end

function AI.initialize_state(
problem::Problem, algorithm::Algorithm; iterate, kwargs...
)
stopping_criterion_state = AI.initialize_state(
problem, algorithm, algorithm.stopping_criterion; iterate
function finalize_substate!(
problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State
)
return DefaultState(; iterate, stopping_criterion_state, kwargs...)
end

# ============================ DefaultState ================================================

@kwdef mutable struct DefaultState{
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
} <: State
iterate::Iterate
iteration::Int = 0
stopping_criterion_state::StoppingCriterionState
state.iterate = substate.iterate
return state
end

# ============================ increment! ==================================================

# Custom version of `increment!` that also takes the problem and algorithm as arguments.
function AI.increment!(problem::Problem, algorithm::Algorithm, state::State)
return AI.increment!(state)
function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State)
subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state)
AI.solve!(subproblem, subalgorithm, substate)
finalize_substate!(problem, algorithm, state, substate)
return state
end

# ============================ AlgorithmIterator ===========================================
# ============================ NestedState =================================================

abstract type AlgorithmIterator end
# State that wraps an inner `substate` and forwards `:iterate` accesses to it,
# so the inner-loop iterate is shared without duplicating storage on the outer
# state. Subtypes must store the inner state as a field named `substate`.
abstract type NestedState <: AI.State end

function algorithm_iterator(
problem::Problem, algorithm::Algorithm, state::State
)
return DefaultAlgorithmIterator(problem, algorithm, state)
end

function AI.is_finished!(iterator::AlgorithmIterator)
return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state)
end
function AI.is_finished(iterator::AlgorithmIterator)
return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state)
# Use `getfield` on the right-hand side so future edits to this forwarder
# can't accidentally recurse through the overload.
function Base.getproperty(state::NestedState, name::Symbol)
name === :iterate && return getfield(state, :substate).iterate
return getfield(state, name)
end
function AI.increment!(iterator::AlgorithmIterator)
return AI.increment!(iterator.problem, iterator.algorithm, iterator.state)
function Base.setproperty!(state::NestedState, name::Symbol, value)
name === :iterate && return (getfield(state, :substate).iterate = value)
return setfield!(state, name, value)
end
function AI.step!(iterator::AlgorithmIterator)
return AI.step!(iterator.problem, iterator.algorithm, iterator.state)
end
function Base.iterate(iterator::AlgorithmIterator, init = nothing)
AI.is_finished!(iterator) && return nothing
AI.increment!(iterator)
AI.step!(iterator)
return iterator.state, nothing
function Base.propertynames(state::NestedState)
return (fieldnames(typeof(state))..., :iterate)
end

struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
problem::Problem
algorithm::Algorithm
state::State
end
# ============================ select_algorithm / default_algorithm ========================

# ============================ with_algorithmlogger ========================================
# Like `MatrixAlgebraKit.select_algorithm` / `default_algorithm`, but
# selection-relevant inputs are packed into an `args` tuple so the value
# and type domains stay disjoint: `(1.2,)` vs `Tuple{Float64}`. Strategy
# types subtype `AbstractAlgorithm` so the passthrough overload is generic.
abstract type AbstractAlgorithm end

# Allow passing functions, not just CallbackActions.
@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...)
return AI.with_algorithmlogger(f, args...)
function default_algorithm(f, ::Type{Args}; kwargs...) where {Args <: Tuple}
return throw(MethodError(default_algorithm, (f, Args)))
end
@inline function with_algorithmlogger(f, args::Pair{Symbol}...)
return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...)
function default_algorithm(f, args::Tuple; kwargs...)
return default_algorithm(f, typeof(args); kwargs...)
end

# ============================ NestedAlgorithm =============================================

abstract type NestedAlgorithm <: Algorithm end

nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...)
function nested_algorithm(f::Function, iterable; kwargs...)
return DefaultNestedAlgorithm(f, iterable; kwargs...)
function select_algorithm(f, alg, args::Tuple; kwargs...)
return select_algorithm(f, alg, typeof(args); kwargs...)
end

max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms)

function get_subproblem(
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State
function select_algorithm(f, ::Nothing, ::Type{Args}; kwargs...) where {Args <: Tuple}
return default_algorithm(f, Args; kwargs...)
end
function select_algorithm(f, alg::NamedTuple, ::Type{Args}; kwargs...) where {Args <: Tuple}
isempty(kwargs) || throw(
ArgumentError(
"Additional keyword arguments are not allowed when `alg` is a `NamedTuple`."
)
)
subproblem = problem
subalgorithm = algorithm.algorithms[state.iteration]
substate = AI.initialize_state(subproblem, subalgorithm; state.iterate)
return subproblem, subalgorithm, substate
return default_algorithm(f, Args; alg...)
end

function set_substate!(
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State
function select_algorithm(f, alg::AbstractAlgorithm, ::Type{<:Tuple}; kwargs...)
isempty(kwargs) || throw(
ArgumentError(
"Additional keyword arguments are not allowed when `alg` is an `AbstractAlgorithm` instance."
)
)
state.iterate = substate.iterate
return state
return alg
end

function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State)
# Get the subproblem, subalgorithm, and substate.
subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state)

# Solve the subproblem with the subalgorithm.
AI.solve!(subproblem, subalgorithm, substate)

# Update the state with the substate.
set_substate!(problem, algorithm, state, substate)
# ============================ StopWhenConverged ===========================================

return state
# Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`.
# Concrete iterate types must supply an `iterate_diff` method.
function iterate_diff(a, b)
return throw(MethodError(iterate_diff, (a, b)))
end

#=
DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm})

An algorithm that consists of running an algorithm at each iteration
from a list of stored algorithms.
=#
@kwdef struct DefaultNestedAlgorithm{
ChildAlgorithm <: Algorithm,
Algorithms <: AbstractVector{ChildAlgorithm},
StoppingCriterion <: AI.StoppingCriterion,
} <: NestedAlgorithm
algorithms::Algorithms
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
@kwdef struct StopWhenConverged <: AI.StoppingCriterion
tol::Float64
end
function DefaultNestedAlgorithm(f::Function, iterable; kwargs...)
return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...)
end

# ============================ FlattenedAlgorithm ==========================================

# Flatten a nested algorithm.
abstract type FlattenedAlgorithm <: Algorithm end
abstract type FlattenedAlgorithmState <: State end

function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...)
return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...)
@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState
delta::Float64 = Inf
at_iteration::Int = -1
previous_iterate::Iterate
end

function AI.initialize_state(
problem::Problem, algorithm::FlattenedAlgorithm; kwargs...
)
stopping_criterion_state = AI.initialize_state(
problem, algorithm, algorithm.stopping_criterion
)
return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...)
end
function AI.increment!(
problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState
)
# Increment the total iteration count.
state.iteration += 1
# TODO: Use `is_finished!` instead?
if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration])
# We're on the last iteration of the child algorithm, so move to the next
# child algorithm.
state.parent_iteration += 1
state.child_iteration = 1
else
# Iterate the child algorithm.
state.child_iteration += 1
end
return state
function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate)
return StopWhenConvergedState(; previous_iterate = copy(iterate))
end
function AI.step!(
problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState
)
algorithm_sweep = algorithm.algorithms[state.parent_iteration]
state_sweep = AI.initialize_state(
problem, algorithm_sweep;
state.iterate, iteration = state.child_iteration

function AI.initialize_state!(
::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState
)
AI.step!(problem, algorithm_sweep, state_sweep)
state.iterate = state_sweep.iterate
return state
st.delta = Inf
return st
end

@kwdef struct DefaultFlattenedAlgorithm{
ChildAlgorithm <: Algorithm,
Algorithms <: AbstractVector{ChildAlgorithm},
StoppingCriterion <: AI.StoppingCriterion,
} <: FlattenedAlgorithm
algorithms::Algorithms
stopping_criterion::StoppingCriterion =
AI.StopAfterIteration(sum(max_iterations, algorithms))
end
function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
end
function AI.is_finished!(
problem::AI.Problem,
algorithm::AI.Algorithm,
state::AI.State,
c::StopWhenConverged,
st::StopWhenConvergedState
)
iterate = state.iterate
previous_iterate = st.previous_iterate

@kwdef mutable struct DefaultFlattenedAlgorithmState{
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
} <: FlattenedAlgorithmState
iterate::Iterate
iteration::Int = 0
parent_iteration::Int = 1
child_iteration::Int = 0
stopping_criterion_state::StoppingCriterionState
end
delta = iterate_diff(iterate, previous_iterate)

# ============================ NonIterativeAlgorithm =======================================
st.previous_iterate = copy(iterate)

# Algorithm that only performs a single step.
abstract type NonIterativeAlgorithm <: Algorithm end
abstract type NonIterativeAlgorithmState <: State end
# delta = 0 initially, so skip this the first time.
state.iteration == 0 && return false

function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...)
return DefaultNonIterativeAlgorithmState(; kwargs...)
end
st.delta = delta

function AI.initialize_state!(
problem::Problem,
algorithm::NonIterativeAlgorithm,
state::NonIterativeAlgorithmState
)
return state
end
if AI.is_finished(problem, algorithm, state, c, st)
st.at_iteration = state.iteration
return true
end

function AI.solve_loop!(problem::Problem, algorithm::NonIterativeAlgorithm, state::State)
return throw(MethodError(AI.solve_loop!, (problem, algorithm, state)))
return false
end

@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <:
NonIterativeAlgorithmState
iterate::Iterate
function AI.is_finished(
::AI.Problem,
::AI.Algorithm,
::AI.State,
c::StopWhenConverged,
st::StopWhenConvergedState
)
return st.delta < c.tol
end

end
4 changes: 1 addition & 3 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ include("abstracttensornetwork.jl")
include("tensornetwork.jl")
include("TensorNetworkGenerators/TensorNetworkGenerators.jl")
include("contract_network.jl")
include("sweeping/utils.jl")
include("sweeping/eigenproblem.jl")

include("beliefpropagation/messagecache.jl")
include("beliefpropagation/beliefpropagationproblem.jl")
include("beliefpropagation/beliefpropagation.jl")

end
Loading
Loading