diff --git a/Project.toml b/Project.toml index 301fb17..c8358d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" -version = "0.4.2" +version = "0.4.3" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index d9edb0d..a95e0e0 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -2,250 +2,158 @@ module AlgorithmsInterfaceExtensions import AlgorithmsInterface as AI -# ========================== Patches for AlgorithmsInterface.jl ============================ +# ============================ NestedAlgorithm ============================================= -abstract type Problem <: AI.Problem end -abstract type Algorithm <: AI.Algorithm end -abstract type State <: AI.State end +abstract type NestedAlgorithm <: AI.Algorithm end -function AI.initialize_state!( - problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs... - ) - for (k, v) in pairs(kwargs) - setproperty!(state, k, v) - end - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state +# Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it +# returns the `(subproblem, subalgorithm, substate)` tuple that the next +# inner `AI.solve!` call consumes. The default `finalize_substate!` copies +# the substate's iterate back into the parent state; subtypes can override +# when more is required. +function initialize_subsolve( + problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State ) - return state + return throw(MethodError(initialize_subsolve, (problem, algorithm, state))) end -function AI.initialize_state( - problem::Problem, algorithm::Algorithm; iterate, kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion; iterate +function finalize_substate!( + problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State ) - return DefaultState(; iterate, stopping_criterion_state, kwargs...) -end - -# ============================ DefaultState ================================================ - -@kwdef mutable struct DefaultState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: State - iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState + state.iterate = substate.iterate + return state end -# ============================ increment! ================================================== - -# Custom version of `increment!` that also takes the problem and algorithm as arguments. -function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) - return AI.increment!(state) +function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) + subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state) + AI.solve!(subproblem, subalgorithm, substate) + finalize_substate!(problem, algorithm, state, substate) + return state end -# ============================ AlgorithmIterator =========================================== +# ============================ NestedState ================================================= -abstract type AlgorithmIterator end +# State that wraps an inner `substate` and forwards `:iterate` accesses to it, +# so the inner-loop iterate is shared without duplicating storage on the outer +# state. Subtypes must store the inner state as a field named `substate`. +abstract type NestedState <: AI.State end -function algorithm_iterator( - problem::Problem, algorithm::Algorithm, state::State - ) - return DefaultAlgorithmIterator(problem, algorithm, state) -end - -function AI.is_finished!(iterator::AlgorithmIterator) - return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state) -end -function AI.is_finished(iterator::AlgorithmIterator) - return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state) +# Use `getfield` on the right-hand side so future edits to this forwarder +# can't accidentally recurse through the overload. +function Base.getproperty(state::NestedState, name::Symbol) + name === :iterate && return getfield(state, :substate).iterate + return getfield(state, name) end -function AI.increment!(iterator::AlgorithmIterator) - return AI.increment!(iterator.problem, iterator.algorithm, iterator.state) +function Base.setproperty!(state::NestedState, name::Symbol, value) + name === :iterate && return (getfield(state, :substate).iterate = value) + return setfield!(state, name, value) end -function AI.step!(iterator::AlgorithmIterator) - return AI.step!(iterator.problem, iterator.algorithm, iterator.state) -end -function Base.iterate(iterator::AlgorithmIterator, init = nothing) - AI.is_finished!(iterator) && return nothing - AI.increment!(iterator) - AI.step!(iterator) - return iterator.state, nothing +function Base.propertynames(state::NestedState) + return (fieldnames(typeof(state))..., :iterate) end -struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator - problem::Problem - algorithm::Algorithm - state::State -end +# ============================ select_algorithm / default_algorithm ======================== -# ============================ with_algorithmlogger ======================================== +# Like `MatrixAlgebraKit.select_algorithm` / `default_algorithm`, but +# selection-relevant inputs are packed into an `args` tuple so the value +# and type domains stay disjoint: `(1.2,)` vs `Tuple{Float64}`. Strategy +# types subtype `AbstractAlgorithm` so the passthrough overload is generic. +abstract type AbstractAlgorithm end -# Allow passing functions, not just CallbackActions. -@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) - return AI.with_algorithmlogger(f, args...) +function default_algorithm(f, ::Type{Args}; kwargs...) where {Args <: Tuple} + return throw(MethodError(default_algorithm, (f, Args))) end -@inline function with_algorithmlogger(f, args::Pair{Symbol}...) - return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) +function default_algorithm(f, args::Tuple; kwargs...) + return default_algorithm(f, typeof(args); kwargs...) end -# ============================ NestedAlgorithm ============================================= - -abstract type NestedAlgorithm <: Algorithm end - -nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...) -function nested_algorithm(f::Function, iterable; kwargs...) - return DefaultNestedAlgorithm(f, iterable; kwargs...) +function select_algorithm(f, alg, args::Tuple; kwargs...) + return select_algorithm(f, alg, typeof(args); kwargs...) end - -max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) - -function get_subproblem( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State +function select_algorithm(f, ::Nothing, ::Type{Args}; kwargs...) where {Args <: Tuple} + return default_algorithm(f, Args; kwargs...) +end +function select_algorithm(f, alg::NamedTuple, ::Type{Args}; kwargs...) where {Args <: Tuple} + isempty(kwargs) || throw( + ArgumentError( + "Additional keyword arguments are not allowed when `alg` is a `NamedTuple`." + ) ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + return default_algorithm(f, Args; alg...) end - -function set_substate!( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State +function select_algorithm(f, alg::AbstractAlgorithm, ::Type{<:Tuple}; kwargs...) + isempty(kwargs) || throw( + ArgumentError( + "Additional keyword arguments are not allowed when `alg` is an `AbstractAlgorithm` instance." + ) ) - state.iterate = substate.iterate - return state + return alg end -function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) - # Get the subproblem, subalgorithm, and substate. - subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) - - # Solve the subproblem with the subalgorithm. - AI.solve!(subproblem, subalgorithm, substate) - - # Update the state with the substate. - set_substate!(problem, algorithm, state, substate) +# ============================ StopWhenConverged =========================================== - return state +# Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`. +# Concrete iterate types must supply an `iterate_diff` method. +function iterate_diff(a, b) + return throw(MethodError(iterate_diff, (a, b))) end -#= - DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm}) - -An algorithm that consists of running an algorithm at each iteration -from a list of stored algorithms. -=# -@kwdef struct DefaultNestedAlgorithm{ - ChildAlgorithm <: Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 end -function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) - return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...) -end - -# ============================ FlattenedAlgorithm ========================================== -# Flatten a nested algorithm. -abstract type FlattenedAlgorithm <: Algorithm end -abstract type FlattenedAlgorithmState <: State end - -function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...) +@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState + delta::Float64 = Inf + at_iteration::Int = -1 + previous_iterate::Iterate end -function AI.initialize_state( - problem::Problem, algorithm::FlattenedAlgorithm; kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...) -end -function AI.increment!( - problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState - ) - # Increment the total iteration count. - state.iteration += 1 - # TODO: Use `is_finished!` instead? - if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration]) - # We're on the last iteration of the child algorithm, so move to the next - # child algorithm. - state.parent_iteration += 1 - state.child_iteration = 1 - else - # Iterate the child algorithm. - state.child_iteration += 1 - end - return state +function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) + return StopWhenConvergedState(; previous_iterate = copy(iterate)) end -function AI.step!( - problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState - ) - algorithm_sweep = algorithm.algorithms[state.parent_iteration] - state_sweep = AI.initialize_state( - problem, algorithm_sweep; - state.iterate, iteration = state.child_iteration + +function AI.initialize_state!( + ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState ) - AI.step!(problem, algorithm_sweep, state_sweep) - state.iterate = state_sweep.iterate - return state + st.delta = Inf + return st end -@kwdef struct DefaultFlattenedAlgorithm{ - ChildAlgorithm <: Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: FlattenedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = - AI.StopAfterIteration(sum(max_iterations, algorithms)) -end -function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) -end +function AI.is_finished!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + iterate = state.iterate + previous_iterate = st.previous_iterate -@kwdef mutable struct DefaultFlattenedAlgorithmState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: FlattenedAlgorithmState - iterate::Iterate - iteration::Int = 0 - parent_iteration::Int = 1 - child_iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end + delta = iterate_diff(iterate, previous_iterate) -# ============================ NonIterativeAlgorithm ======================================= + st.previous_iterate = copy(iterate) -# Algorithm that only performs a single step. -abstract type NonIterativeAlgorithm <: Algorithm end -abstract type NonIterativeAlgorithmState <: State end + # delta = 0 initially, so skip this the first time. + state.iteration == 0 && return false -function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) - return DefaultNonIterativeAlgorithmState(; kwargs...) -end + st.delta = delta -function AI.initialize_state!( - problem::Problem, - algorithm::NonIterativeAlgorithm, - state::NonIterativeAlgorithmState - ) - return state -end + if AI.is_finished(problem, algorithm, state, c, st) + st.at_iteration = state.iteration + return true + end -function AI.solve_loop!(problem::Problem, algorithm::NonIterativeAlgorithm, state::State) - return throw(MethodError(AI.solve_loop!, (problem, algorithm, state))) + return false end -@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: - NonIterativeAlgorithmState - iterate::Iterate +function AI.is_finished( + ::AI.Problem, + ::AI.Algorithm, + ::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + return st.delta < c.tol end end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 0b2b898..41ce78e 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -12,10 +12,8 @@ include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("sweeping/utils.jl") -include("sweeping/eigenproblem.jl") include("beliefpropagation/messagecache.jl") -include("beliefpropagation/beliefpropagationproblem.jl") +include("beliefpropagation/beliefpropagation.jl") end diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl new file mode 100644 index 0000000..d6dfabc --- /dev/null +++ b/src/beliefpropagation/beliefpropagation.jl @@ -0,0 +1,255 @@ +import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI +using .AlgorithmsInterfaceExtensions: StopWhenConverged, iterate_diff +using BackendSelection: @Algorithm_str, Algorithm +using DataGraphs: edge_data +using Graphs: AbstractEdge, edges, edgetype, has_edge, vertices +using LinearAlgebra: norm, normalize +using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph +using NamedGraphs.PartitionedGraphs: quotientvertices + +# === Top-level user entry point === + +default_beliefpropagation_edges(graph) = forest_cover_edge_sequence(graph) + +select_beliefpropagation_stopping_criterion(c::AI.StoppingCriterion) = c +function select_beliefpropagation_stopping_criterion(::Nothing) + return throw( + ArgumentError( + "`stopping_criterion` must be specified, e.g.\n" * + " `stopping_criterion = (; maxiter = 10)`,\n" * + " `stopping_criterion = (; maxiter = 10, tol = 1.0e-10)`, or\n" * + " `stopping_criterion = AI.StopAfterIteration(10) | StopWhenConverged(1.0e-10)`." + ) + ) +end +function select_beliefpropagation_stopping_criterion(kwargs::NamedTuple) + return select_beliefpropagation_stopping_criterion(; kwargs...) +end +function select_beliefpropagation_stopping_criterion(; + maxiter = nothing, tol = nothing, kwargs... + ) + if !isempty(kwargs) + throw( + ArgumentError( + "Unrecognized `stopping_criterion` kwargs: $(keys(kwargs)). " * + "Supported: `maxiter`, `tol`." + ) + ) + end + if isnothing(maxiter) && isnothing(tol) + throw( + ArgumentError("At least one of `maxiter` or `tol` must be specified.") + ) + end + criterion = nothing + if !isnothing(maxiter) + criterion = AI.StopAfterIteration(maxiter) + end + if !isnothing(tol) + converged = StopWhenConverged(; tol) + criterion = isnothing(criterion) ? converged : criterion | converged + end + return criterion +end + +function beliefpropagation( + factors, messages; + edges = default_beliefpropagation_edges(factors), + stopping_criterion = nothing, + message_update_algorithm = nothing + ) + problem = BeliefPropagationProblem(factors) + cache = MessageCache(messages) + + # No concrete `edge` value here, so the args tuple uses `edgetype(factors)`. + message_update_algorithm = AIE.select_algorithm( + message_update!, + message_update_algorithm, + Tuple{typeof(cache), typeof(factors), edgetype(factors)} + ) + subalgorithm = BeliefPropagationSweepAlgorithm(; + message_update_algorithm, + stopping_criterion = AI.StopAfterIteration(length(edges)) + ) + stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) + algorithm = BeliefPropagationAlgorithm(; edges, subalgorithm, stopping_criterion) + + return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) +end + +# === Layer 1: BP outer loop (iterative) === + +struct BeliefPropagationProblem{Factors} <: AI.Problem + factors::Factors +end + +@kwdef struct BeliefPropagationAlgorithm{ + Edges, + Subalgorithm <: AI.Algorithm, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + edges::Edges + subalgorithm::Subalgorithm + stopping_criterion::StoppingCriterion +end + +@kwdef mutable struct BeliefPropagationState{ + Substate <: AI.State, StoppingCriterionState <: AI.StoppingCriterionState, + } <: AIE.NestedState + substate::Substate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +function AI.initialize_state( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm; + iterate, iteration::Int = 0 + ) + subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) + substate = AI.initialize_state(subproblem, algorithm.subalgorithm; iterate) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return BeliefPropagationState(; iteration, stopping_criterion_state, substate) +end + +function AI.initialize_state!( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState; + iteration::Int = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AIE.initialize_subsolve( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState + ) + subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) + return subproblem, algorithm.subalgorithm, state.substate +end + +# === Layer 2: one sweep over edges (iterative) === + +struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem + factors::Factors + edges::Edges +end + +@kwdef struct BeliefPropagationSweepAlgorithm{ + MessageUpdateAlgorithm, + StoppingCriterion <: AI.StoppingCriterion, + } <: AI.Algorithm + message_update_algorithm::MessageUpdateAlgorithm = SimpleMessageUpdate() + stopping_criterion::StoppingCriterion +end + +@kwdef mutable struct BeliefPropagationSweepState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +function AI.initialize_state( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm; + iterate, iteration::Int = 0 + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return BeliefPropagationSweepState(; iterate, iteration, stopping_criterion_state) +end + +function AI.initialize_state!( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm, + state::BeliefPropagationSweepState; + iteration::Int = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.step!( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm, + state::BeliefPropagationSweepState + ) + edge = problem.edges[state.iteration] + message_update!( + algorithm.message_update_algorithm, state.iterate, problem.factors, edge + ) + return state +end + +# === Layer 3: single-edge message update strategy === + +# Strategy interface: a `MessageUpdateAlgorithm` defines how a single +# message is computed and written back into the message store. Plug in a +# new strategy by subtyping `MessageUpdateAlgorithm` and overloading +# `message_update!(strategy, cache, factors, edge)`. +abstract type MessageUpdateAlgorithm <: AIE.AbstractAlgorithm end + +function message_update! end + +# `args` tuple mirrors the `message_update!(cache, factors, edge)` call shape. +function AIE.default_algorithm(::typeof(message_update!), ::Type{<:Tuple}; kwargs...) + return SimpleMessageUpdate(; kwargs...) +end + +# Convenience entry: pick the strategy via `AIE.select_algorithm` +# (accepts either `alg = ::MessageUpdateAlgorithm` / `::NamedTuple`, or flat +# kwargs forwarded to the default algorithm), then dispatch. +function message_update!(cache, factors, edge; alg = nothing, kwargs...) + return message_update!( + AIE.select_algorithm(message_update!, alg, (cache, factors, edge); kwargs...), + cache, factors, edge + ) +end + +@kwdef struct SimpleMessageUpdate{ContractionAlg} <: MessageUpdateAlgorithm + normalize::Bool = true + contraction_alg::ContractionAlg = Algorithm"exact" +end + +function message_update!(algorithm::SimpleMessageUpdate, cache, factors, edge) + messages = collect(incoming_messages(cache, edge)) + factor = factors[src(edge)] + + new_message = contract_network([messages; [factor]]; algorithm.contraction_alg) + + if algorithm.normalize + message_norm = sum(new_message) + if !iszero(message_norm) + new_message /= message_norm + end + end + + cache[edge] = new_message + return cache +end + +# === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === + +function AIE.iterate_diff(cache1::MessageCache, cache2::MessageCache) + return maximum(edges(cache1)) do edge + m1 = cache1[edge] + m2 = cache2[edge] + return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + end +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl deleted file mode 100644 index 004e449..0000000 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ /dev/null @@ -1,220 +0,0 @@ -import .AlgorithmsInterfaceExtensions as AIE -import AlgorithmsInterface as AI -using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: edge_data -using Graphs: AbstractEdge, edges, has_edge, vertices -using LinearAlgebra: norm, normalize -using NamedDimsArrays: AbstractNamedDimsArray -using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph -using NamedGraphs.PartitionedGraphs: quotientvertices - -@kwdef struct StopWhenConverged <: AI.StoppingCriterion - tol::Float64 -end - -@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState - delta::Float64 = Inf - at_iteration::Int = -1 - previous_iterate::Iterate -end - -function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged; iterate) - return StopWhenConvergedState(; previous_iterate = copy(iterate)) -end - -function AI.initialize_state!( - ::AIE.Problem, - ::AIE.Algorithm, - ::StopWhenConverged, - st::StopWhenConvergedState - ) - st.delta = Inf - return st -end - -function AI.is_finished!( - problem::AIE.Problem, - algorithm::AIE.Algorithm, - state::AIE.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - iterate = state.iterate - previous_iterate = st.previous_iterate - - delta = iterate_diff(iterate, previous_iterate) - - st.previous_iterate = copy(iterate) - - # maxdiff = 0.0 initially, so skip this the first time. - state.iteration == 0 && return false - - st.delta = delta - - if AI.is_finished(problem, algorithm, state, c, st) - st.at_iteration = state.iteration - return true - end - - return false -end - -function AI.is_finished( - ::AIE.Problem, - ::AIE.Algorithm, - ::AIE.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - return st.delta < c.tol -end - -struct BeliefPropagationProblem{Factors} <: AIE.Problem - factors::Factors -end - -function iterate_diff( - cache1::MessageCache, - cache2::MessageCache - ) - return maximum(edges(cache1)) do edge - m1 = cache1[edge] - m2 = cache2[edge] - return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) - end -end - -@kwdef struct BeliefPropagation{ - ChildAlgorithm <: AIE.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagation(f::Function, niterations::Int; kwargs...) - return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) -end - -struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} - edge::E - kwargs::Kwargs -end - -function SimpleMessageUpdate( - edge; - normalize = true, - contraction_alg = Algorithm"exact", - kwargs... - ) - return SimpleMessageUpdate( - edge, - (; normalize, contraction_alg, kwargs...) - ) -end - -function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) - if name in (:edge, :kwargs) - return getfield(alg, name) - else - return getproperty(getfield(alg, :kwargs), name) - end -end - -AI.initialize_state(::BeliefPropagationProblem, ::SimpleMessageUpdate; iterate) = iterate - -struct BeliefPropagationSweep{ - ChildAlgorithm, Algorithms <: AbstractVector{ChildAlgorithm}, - } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::AI.StopAfterIteration - function BeliefPropagationSweep(; algorithms) - stopping_criterion = AI.StopAfterIteration(length(algorithms)) - return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) - end -end - -function BeliefPropagationSweep(f::Function, edges) - return BeliefPropagationSweep(; algorithms = f.(edges)) -end - -function AIE.set_substate!( - ::BeliefPropagationProblem, - ::BeliefPropagationSweep, - state::AIE.DefaultState, - cache::MessageCache - ) - state.iterate = cache - - return state -end - -function AI.solve!( - problem::BeliefPropagationProblem, - algorithm::SimpleMessageUpdate, - cache::MessageCache - ) - edge = algorithm.edge - - messages = collect(incoming_messages(cache, edge)) - factor = problem.factors[src(edge)] - - new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) - - if algorithm.normalize - message_norm = sum(new_message) - if !iszero(message_norm) - new_message /= message_norm - end - end - - cache[edge] = new_message - - return cache -end - -function beliefpropagation( - factors, messages; - edges = nothing, - maxiter = is_tree(factors) ? 1 : nothing, - stopping_criterion = nothing, - kwargs... - ) - if isnothing(maxiter) - throw( - ArgumentError( - "`maxiter` must be specified for non-tree graphs, even when - `stopping_criterion` is provided." - ) - ) - end - - cache = MessageCache(messages) - problem = BeliefPropagationProblem(factors) - - ## Algorithm construction: - - edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges - - base_stopping_criterion = AI.StopAfterIteration(maxiter) - - if !isnothing(stopping_criterion) - base_stopping_criterion |= stopping_criterion - end - - stopping_criterion = base_stopping_criterion - - extended_kwargs = extend_columns((; kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, maxiter) - - algorithm = BeliefPropagation(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweep(edges) do edge - return SimpleMessageUpdate(edge; edge_kwargs[repnum]...) - end - end - - ## - - return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) -end diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl deleted file mode 100644 index 8fefbd0..0000000 --- a/src/sweeping/eigenproblem.jl +++ /dev/null @@ -1,44 +0,0 @@ -import .AlgorithmsInterfaceExtensions as AIE -import AlgorithmsInterface as AI - -function dmrg(operator, algorithm, state) - problem = EigenProblem(operator) - return AI.solve(problem, algorithm; iterate = state).iterate -end -function dmrg(operator, state; kwargs...) - problem = EigenProblem(operator) - algorithm = select_algorithm(dmrg, operator, state; kwargs...) - return AI.solve(problem, algorithm; iterate = state).iterate -end - -# TODO: Allow specifying the region algorithm type? -function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, kwargs...) - extended_kwargs = extend_columns((; kwargs...), nsweeps) - region_kwargs = rows(extended_kwargs) - return AIE.nested_algorithm(nsweeps) do i - return AIE.nested_algorithm(length(regions)) do j - return EigsolveRegion(regions[j]; region_kwargs[i]...) - end - end -end -#= - EigenProblem(operator) - -Represents the problem we are trying to solve and minimal algorithm-independent -information, so for an eigenproblem it is the operator we want the eigenvector of. -=# -struct EigenProblem{Operator} <: AIE.Problem - operator::Operator -end - -struct EigsolveRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm - region::R - kwargs::Kwargs -end -EigsolveRegion(region; kwargs...) = EigsolveRegion(region, (; kwargs...)) - -function AI.solve!( - problem::EigenProblem, algorithm::EigsolveRegion, state::AIE.State; kwargs... - ) - return error("EigsolveRegion step for EigenProblem not implemented yet.") -end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl deleted file mode 100644 index 39e09e4..0000000 --- a/src/sweeping/utils.jl +++ /dev/null @@ -1,12 +0,0 @@ -# Utility functions for processing keyword arguments. -function repeat_last(v::AbstractVector, len::Int) - return [v; fill(v[end], max(len - length(v), 0))] -end -repeat_last(v, len::Int) = fill(v, len) -function extend_columns(nt::NamedTuple, len::Int) - return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) -end -rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) - return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] -end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 6f80527..f580826 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -1,430 +1,139 @@ import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using Test: @test, @testset +using Test: @test, @test_throws, @testset -# Define test problems, algorithms, and states for testing -struct TestProblem <: AIE.Problem - data::Vector{Float64} +# Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms +# and picks them by iteration index. Mirrors how `BeliefPropagationAlgorithm` +# shapes itself on top of the minimal `AIE.NestedAlgorithm`. +struct TestProblem <: AI.Problem end + +@kwdef struct TestChildAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AI.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(2) end -@kwdef struct TestAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) +@kwdef mutable struct TestChildState{SCState <: AI.StoppingCriterionState} <: AI.State + iterate::Vector{Float64} + iteration::Int = 0 + stopping_criterion_state::SCState end -@kwdef struct TestAlgorithmStep{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(5) +function AI.initialize_state( + problem::TestProblem, algorithm::TestChildAlgorithm; + iterate, kwargs... + ) + sc_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return TestChildState(; iterate, stopping_criterion_state = sc_state, kwargs...) end -function AI.step!( - problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState +function AI.initialize_state!( + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state ) - state.iterate .+= 1 # Simple increment step return state end +function AI.increment!( + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState + ) + return AI.increment!(state) +end + function AI.step!( - problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState ) - state.iterate .+= 2 # Different increment step + state.iterate .+= 1 return state end -@testset "AlgorithmsInterfaceExtensions" begin - @testset "DefaultState" begin - # Test DefaultState construction - iterate = [1.0, 2.0, 3.0] - stopping_criterion_state = AI.initialize_state( - TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion - ) - state = AIE.DefaultState(; iterate = copy(iterate), stopping_criterion_state) - @test state.iterate == iterate - @test state.iteration == 0 - @test state.stopping_criterion_state isa AI.StoppingCriterionState - - # Test DefaultState with custom iteration - state.iteration = 5 - @test state.iteration == 5 - end - - @testset "initialize_state!" begin - # Test initialize_state! with iterate kwarg - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - state = AIE.DefaultState(; - iteration = 2, iterate = [0.0, 0.0], stopping_criterion_state - ) - AI.initialize_state!(problem, algorithm, state) - @test state.iterate == [0.0, 0.0] - @test state.iteration == 0 - @test state.stopping_criterion_state == stopping_criterion_state - end - - @testset "initialize_state" begin - # Test initialize_state without exclamation - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - - state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) - @test state isa AIE.DefaultState - @test state.iteration == 0 - end - - @testset "increment!" begin - # Test increment! with problem and algorithm - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - - # Increment and verify iteration counter increases - AI.increment!(problem, algorithm, state) - @test state.iteration == 1 - - AI.increment!(problem, algorithm, state) - @test state.iteration == 2 - end - - @testset "solve! and solve" begin - # Test solve! with simple problem - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) - - initial_iterate = [10.0, 20.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - - # Solve with custom initial iterate - initial_iterate = [5.0, 10.0] - final_iterate = AI.solve!( - problem, algorithm, state; iterate = copy(initial_iterate) - ) - - @test state.iteration == 3 - @test final_iterate == state.iterate - # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] - @test state.iterate ≈ [8.0, 13.0] - - # Test solve without exclamation - problem2 = TestProblem([1.0, 2.0]) - algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate2 = [5.0, 10.0] - - final_iterate2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) - @test final_iterate2 ≈ [7.0, 12.0] - end - - @testset "DefaultAlgorithmIterator" begin - # Test algorithm iterator creation - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - @test iterator isa AIE.DefaultAlgorithmIterator - @test iterator.problem === problem - @test iterator.algorithm === algorithm - @test iterator.state === state - - # Test iteration interface - @test !AI.is_finished!(iterator) - - # Step through iterator - state_out, _ = iterate(iterator) - @test state_out.iteration == 1 - @test state_out.iterate ≈ [1.0, 1.0] # Incremented by step! - - state_out, _ = iterate(iterator) - @test state_out.iteration == 2 - - @test AI.is_finished!(iterator) - end - - @testset "DefaultNestedAlgorithm" begin - # Test creating nested algorithm with function - nested_alg = AIE.nested_algorithm(3) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - @test nested_alg isa AIE.DefaultNestedAlgorithm - @test length(nested_alg.algorithms) == 3 - @test AIE.max_iterations(nested_alg) == 3 - - # Test stepping through nested algorithm - problem = TestProblem([1.0, 2.0]) - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) - state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - - initial_iterate = [0.0, 0.0] - AI.solve!( - problem, nested_alg, state; iterate = copy(initial_iterate) - ) - - @test state.iteration == 3 - # Each nested algorithm runs once with 2 steps, incrementing by 2 - # Total: 3 algorithms × 2 iterations × 2 increment = 12 - @test state.iterate ≈ [12.0, 12.0] - end - - @testset "NestedAlgorithm basic tests" begin - # Test basic nested algorithm functionality - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - problem = TestProblem([1.0, 2.0]) - - # Test state initialization - state_nested = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) - - @test state_nested isa AIE.DefaultState - @test state_nested.iteration == 0 - @test AIE.max_iterations(nested_alg) == 2 - end - - @testset "increment! for nested algorithms" begin - # Test increment! logic for nested algorithm state - problem = TestProblem([1.0]) - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) - state = AIE.DefaultState(; - iterate = [0.0], - stopping_criterion_state = stopping_criterion_state - ) - - # Test progression through iterations - @test state.iteration == 0 - - AI.increment!(problem, nested_alg, state) - @test state.iteration == 1 - - AI.increment!(problem, nested_alg, state) - @test state.iteration == 2 - end - - @testset "get_subproblem and set_substate!" begin - # Test get_subproblem - problem = TestProblem([1.0, 2.0]) - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) - end - - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) - state = AIE.DefaultState(; - iterate = [5.0, 10.0], - iteration = 1, - stopping_criterion_state - ) - - subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) - @test subproblem === problem - @test subalgorithm === nested_alg.algorithms[1] - @test substate.iterate ≈ [5.0, 10.0] - - # Test set_substate! - new_substate = AIE.DefaultState(; - iterate = [100.0, 200.0], - substate.stopping_criterion_state - ) - AIE.set_substate!(problem, nested_alg, state, new_substate) - @test state.iterate ≈ [100.0, 200.0] - end - - @testset "DefaultFlattenedAlgorithm" begin - # Create nested algorithms that support max_iterations - nested_algs = map(1:3) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(6) # 3 algorithms × 2 iterations each - ) - - @test flattened_alg isa AIE.DefaultFlattenedAlgorithm - @test length(flattened_alg.algorithms) == 3 - - # Test state initialization - problem = TestProblem([1.0, 2.0]) - state_flat = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) - - @test state_flat isa AIE.DefaultFlattenedAlgorithmState - @test state_flat.iteration == 0 - @test state_flat.parent_iteration == 1 - @test state_flat.child_iteration == 0 - end - - @testset "DefaultFlattenedAlgorithmState increment!" begin - # Create nested algorithms for flattened algorithm - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4) - ) - - problem = TestProblem([1.0]) - stopping_criterion_state = AI.initialize_state( - problem, flattened_alg, flattened_alg.stopping_criterion - ) - state = AIE.DefaultFlattenedAlgorithmState(; - iterate = [0.0], - stopping_criterion_state = stopping_criterion_state - ) - - # Test initial state - @test state.iteration == 0 - @test state.parent_iteration == 1 - @test state.child_iteration == 0 - - # First increment - should increment child_iteration - AI.increment!(problem, flattened_alg, state) - @test state.iteration == 1 - @test state.parent_iteration == 1 - @test state.child_iteration == 1 - - # Second increment - should increment child_iteration again - AI.increment!(problem, flattened_alg, state) - @test state.iteration == 2 - @test state.parent_iteration == 2 # Should move to next parent - @test state.child_iteration == 1 - end - - @testset "FlattenedAlgorithm step!" begin - # Test individual step! calls for flattened algorithm - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4) - ) - - problem = TestProblem([1.0, 2.0]) - state = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) - - # Manually step through to test step! functionality - AI.increment!(problem, flattened_alg, state) - @test state.parent_iteration == 1 - @test state.child_iteration == 1 - - AI.step!(problem, flattened_alg, state) - # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 - @test state.iterate ≈ [4.0, 4.0] - end - - @testset "flattened_algorithm helper" begin - # Test the flattened_algorithm helper function - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - # Using the helper function - flattened_alg = AIE.flattened_algorithm(2) do i - AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - @test flattened_alg isa AIE.DefaultFlattenedAlgorithm - @test length(flattened_alg.algorithms) == 2 - end - - @testset "AlgorithmIterator is_finished (without !)" begin - # Test is_finished without mutation - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - # Before any iterations - @test !AI.is_finished(iterator) +@kwdef struct TestNestedAlgorithm{ + ChildAlgorithm <: AI.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end - # Run the algorithm - AI.solve!(problem, algorithm, state; iterate = copy(initial_iterate)) +# Reuse the child-state shape for the parent algorithm too. +function AI.initialize_state( + problem::TestProblem, algorithm::TestNestedAlgorithm; + iterate, kwargs... + ) + sc_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return TestChildState(; iterate, stopping_criterion_state = sc_state, kwargs...) +end - # After completion - @test AI.is_finished(iterator) +function AI.initialize_state!( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::TestChildState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end - @testset "AlgorithmIterator step!" begin - # Test step! method for iterator - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - # Step the iterator - AI.step!(iterator) - @test iterator.state.iterate ≈ [1.0, 1.0] +function AI.increment!( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::TestChildState + ) + return AI.increment!(state) +end - AI.step!(iterator) - @test iterator.state.iterate ≈ [2.0, 2.0] - end +function AIE.initialize_subsolve( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end - @testset "NestedAlgorithm with different sub-algorithms" begin - # Test nested algorithm with varying sub-algorithms - nested_alg = AIE.DefaultNestedAlgorithm(; +@testset "AlgorithmsInterfaceExtensions" begin + @testset "NestedAlgorithm defaults" begin + # The bare `initialize_subsolve` default throws a `MethodError`, + # forcing concrete subtypes to provide their own override. + problem = TestProblem() + algorithm = TestChildAlgorithm() + state = AI.initialize_state(problem, algorithm; iterate = [0.0]) + @test_throws MethodError AIE.initialize_subsolve(problem, algorithm, state) + + # `finalize_substate!` copies the substate's iterate back into the + # parent state. + substate = AI.initialize_state(problem, algorithm; iterate = [42.0]) + AIE.finalize_substate!(problem, algorithm, state, substate) + @test state.iterate == [42.0] + end + + @testset "TestNestedAlgorithm" begin + problem = TestProblem() + nested_alg = TestNestedAlgorithm(; algorithms = [ - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), - TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)), - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestChildAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestChildAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)), ] ) + @test nested_alg isa AIE.NestedAlgorithm - @test AIE.max_iterations(nested_alg) == 3 - @test length(nested_alg.algorithms) == 3 - - problem = TestProblem([1.0, 2.0]) state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) - AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) - - # First algorithm: 1 iteration × 1 increment = 1 - # Second algorithm: 2 iterations × 2 increment = 4 - # Third algorithm: 1 iteration × 1 increment = 1 - # Total: 1 + 4 + 1 = 6 - @test state.iterate ≈ [6.0, 6.0] - @test state.iteration == 3 - end - - @testset "Edge cases" begin - # Test with single nested algorithm - nested_alg = AIE.nested_algorithm(1) do i - return TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - end - - problem = TestProblem([1.0]) - state = AI.initialize_state(problem, nested_alg; iterate = [0.0]) - AI.solve!(problem, nested_alg, state; iterate = [0.0]) - - @test state.iterate ≈ [1.0] - @test state.iteration == 1 + # Two child algorithms: 1 inner step + 2 inner steps = 3 total + # `state.iterate .+= 1` calls. + @test state.iteration == 2 + @test state.iterate ≈ [3.0, 3.0] end end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 01ca6e7..34cb305 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -167,7 +167,9 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) - cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + cache = ITensorNetworksNext.beliefpropagation( + tn, messages; stopping_criterion = (; maxiter = 1) + ) z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -184,7 +186,9 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) - cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + cache = ITensorNetworksNext.beliefpropagation( + tn, messages; stopping_criterion = (; maxiter = 1) + ) z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -198,13 +202,9 @@ end messages = Dict(edge => randt(tn, edge) for edge in all_edges(g)) - stopping_criterion = StopWhenConverged(tol = 1.0e-10) - cache = ITensorNetworksNext.beliefpropagation( - tn, - messages; - maxiter = 10, - stopping_criterion + tn, messages; + stopping_criterion = (; maxiter = 10, tol = 1.0e-10) ) z_bp = exp(bethe_free_energy(tn, cache)) diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl deleted file mode 100644 index dba2570..0000000 --- a/test/test_dmrg.jl +++ /dev/null @@ -1,34 +0,0 @@ -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm -using Test: @test, @testset - -@testset "select_algorithm(dmrg, ...)" begin - operator = "operator" - init = "init" - nsweeps = 3 - regions = ["region1", "region2"] - maxdim = [10, 20] - cutoff = 1.0e-7 - algorithm = select_algorithm(dmrg, operator, init; nsweeps, regions, maxdim, cutoff) - @test algorithm isa AIE.NestedAlgorithm - @test length(algorithm.algorithms) == nsweeps - - maxdims = [10, 20, 20] - cutoffs = [1.0e-7, 1.0e-7, 1.0e-7] - algorithm′ = AIE.nested_algorithm(nsweeps) do i - return AIE.nested_algorithm(length(regions)) do j - return EigsolveRegion( - regions[j]; - maxdim = maxdims[i], - cutoff = cutoffs[i] - ) - end - end - for i in 1:nsweeps - for j in 1:length(regions) - @test algorithm.algorithms[i].algorithms[j] == - algorithm′.algorithms[i].algorithms[j] - end - end -end diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl deleted file mode 100644 index 01881d9..0000000 --- a/test/test_sweeping.jl +++ /dev/null @@ -1,65 +0,0 @@ -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using Test: @test, @testset - -struct TestProblem <: AIE.Problem -end - -struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm - region::R - kwargs::Kwargs -end -TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...)) - -function AI.solve_loop!(problem::TestProblem, algorithm::TestRegion, state::AIE.State) - new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) - state.iterate = [state.iterate; [new_iterate]] - return state -end - -@testset "Sweeping" begin - @testset "TestRegion" begin - algorithm = TestRegion("region"; foo = 1, bar = 2) - @test algorithm isa AIE.NonIterativeAlgorithm - @test algorithm isa AIE.Algorithm - @test algorithm isa AI.Algorithm - @test algorithm.region == "region" - @test algorithm.kwargs == (; foo = 1, bar = 2) - - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [(; region = "region", foo = 1, bar = 2)] - end - @testset "Sweep" begin - algorithm = AIE.nested_algorithm(3) do i - return TestRegion("region$i"; foo = i, bar = 2i) - end - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [ - (; region = "region1", foo = 1, bar = 2), - (; region = "region2", foo = 2, bar = 4), - (; region = "region3", foo = 3, bar = 6), - ] - end - @testset "Sweeping" begin - algorithm = AIE.nested_algorithm(2) do i - AIE.nested_algorithm(3) do j - return TestRegion("sweep$i, region$j"; foo = (i, j), bar = (2i, 2j)) - end - end - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [ - (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), - (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), - (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)), - (; region = "sweep2, region1", foo = (2, 1), bar = (4, 2)), - (; region = "sweep2, region2", foo = (2, 2), bar = (4, 4)), - (; region = "sweep2, region3", foo = (2, 3), bar = (4, 6)), - ] - end -end