Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ OptimKit = "0.4"
Printf = "1"
Random = "1"
Statistics = "1"
TensorKit = "0.16.2"
TensorKit = "0.16.5"
TensorOperations = "5"
TupleTools = "1.6.0"
VectorInterface = "0.4, 0.5"
Expand Down
16 changes: 8 additions & 8 deletions src/Defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ Module containing default algorithm parameter values and arguments.
* `svd_rrule_min_krylovdim=$(Defaults.svd_rrule_min_krylovdim)` : Minimal Krylov dimension of the reverse-rule algorithm (if it is a Krylov algorithm).
* `svd_rrule_verbosity=$(Defaults.svd_rrule_verbosity)` : SVD gradient output verbosity.
* `svd_rrule_alg=:$(Defaults.svd_rrule_alg)` : Reverse-rule algorithm for the SVD gradient.
- `:full` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum
- `:trunc` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace
- `:FullPullback` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum
- `:TruncPullback` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace
- `:GMRES` : GMRES iterative linear solver, see [`KrylovKit.GMRES`](@extref)
- `:BiCGStab` : BiCGStab iterative linear solver, see [`KrylovKit.BiCGStab`](@extref)
- `:Arnoldi` : Arnoldi Krylov algorithm, see the [`KrylovKit.Arnoldi`](@extref)
Expand All @@ -58,8 +58,8 @@ Module containing default algorithm parameter values and arguments.
- `:Lanczos` : Lanczos algorithm for symmetric/Hermitian matrices, see [`KrylovKit.Lanczos`](@extref)
- `:BlockLanczos` : Block version of `:Lanczos` for repeated extremal eigenvalues, see [`KrylovKit.BlockLanczos`](@extref)
* `eigh_rrule_alg=:$(Defaults.eigh_rrule_alg)` : Reverse-rule algorithm for the `eigh` gradient.
- `:full` : Full pullback algorithm for eigendecompositions, see [`PEPSKit.FullEighPullback`](@ref).
- `:trunc` : Truncated reverse-mode algorithm for eigendecompositions, see [`PEPSKit.TruncEighPullback`](@ref).
- `:FullPullback` : MatrixAlgebraKit's [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!) that requires access to the full spectrum
- `:TruncPullback` : MatrixAlgebraKit's [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) solving a Sylvester equation on the truncated subspace
* `eigh_rrule_verbosity=$(Defaults.eigh_rrule_verbosity)` : eigh gradient output verbosity.

## Projectors
Expand Down Expand Up @@ -126,18 +126,18 @@ const svd_fwd_alg = :DefaultAlgorithm # ∈ {:<MatrixAlgebraKit.SVDAlgorithms>,
const svd_rrule_tol = ctmrg_tol
const svd_rrule_min_krylovdim = 48
const svd_rrule_verbosity = -1
const svd_rrule_alg = :full # ∈ {:full, :trunc, :GMRES, :BiCGStab, :Arnoldi}
const svd_rrule_alg = :FullPullback # ∈ {:FullPullback, :TruncPullback, :GMRES, :BiCGStab, :Arnoldi}
const krylovdim_factor = 1.4

# eigh forward & reverse
const eigh_fwd_alg = :DefaultAlgorithm # ∈ {:<MatrixAlgebraKit.EighAlgorithms>, :Lanczos, :BlockLanczos}
const eigh_rrule_alg = :full # ∈ {:full, :trunc}
const eigh_rrule_alg = :FullPullback # ∈ {:FullPullback, :TruncPullback}
const eigh_rrule_verbosity = 0

# QR forward & reverse
const qr_fwd_alg = :Householder
const qr_fwd_alg = :DefaultAlgorithm
const qr_fwd_positive = true
const qr_rrule_alg = :qr
const qr_rrule_alg = :FullPullback
const qr_rrule_verbosity = 0

# Projectors
Expand Down
39 changes: 21 additions & 18 deletions src/utility/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,46 @@
"""
$(TYPEDEF)

Eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_pullback!`.
Reverse-rule algorithm which wraps MatrixAlgebraKit's full pullback methods,
see [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!), [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!), [`qr_pullback!`](@extref MatrixAlgebraKit.qr_pullback!).

## Fields

$(TYPEDFIELDS)

## Constructors

FullEighPullback(; kwargs...)
FullPullback(; kwargs...)

Construct a `FullEighPullback` algorithm struct from the following keyword arguments:
Construct a `FullPullback` algorithm struct from the following keyword arguments:

* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Absolute tolerance for idendifying degenerate subspaces.
* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`.
"""
@kwdef struct FullEighPullback
@kwdef struct FullPullback
degeneracy_atol::Real = Defaults.rrule_degeneracy_atol
verbosity::Int = 0
end

"""
$(TYPEDEF)

Truncated eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_trunc_pullback!`.

Truncated reverse-rule algorithm which wraps MatrixAlgebraKit's truncated pullback methods,
see [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) and [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!).
## Fields

$(TYPEDFIELDS)

## Constructors

TruncEighPullback(; kwargs...)
TruncPullback(; kwargs...)

Construct a `TruncEighPullback` algorithm struct from the following keyword arguments:
Construct a `TruncPullback` algorithm struct from the following keyword arguments:

* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Absolute tolerance for idendifying degenerate subspaces.
* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`.
"""
@kwdef struct TruncEighPullback
@kwdef struct TruncPullback
degeneracy_atol::Real = Defaults.rrule_degeneracy_atol
verbosity::Int = 0
end
Expand Down Expand Up @@ -78,8 +81,8 @@ Construct a `EighAdjoint` algorithm struct based on the following keyword argume
Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied
by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the
following:
- `:full` : MatrixAlgebraKit's [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!) that requires access to the full spectrum
- `:trunc` : MatrixAlgebraKit's [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) solving a Sylvester equation on the truncated subspace
- `:FullPullback` : MatrixAlgebraKit's [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!) that requires access to the full spectrum
- `:TruncPullback` : MatrixAlgebraKit's [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) solving a Sylvester equation on the truncated subspace

!!! note
Manually specifying a `rrule_alg` is considered expert-mode usage, and should only be done when full control over the implementation is desired.
Expand All @@ -101,10 +104,10 @@ const EIGH_FWD_SYMBOLS = IdDict{Symbol, Any}(
:BlockLanczos => (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> IterEigh(; alg = BlockLanczos(; tol, krylovdim), kwargs...),
)
const EIGH_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}(
:full => FullEighPullback, :trunc => TruncEighPullback,
:FullPullback => FullPullback, :TruncPullback => TruncPullback,
)

_default_eigh_rrule_alg(::MatrixAlgebraKit.Algorithm) = :full
_default_eigh_rrule_alg(::MatrixAlgebraKit.Algorithm) = :FullPullback

function EighAdjoint(; fwd_alg = (;), rrule_alg = (;))
# parse forward algorithm
Expand All @@ -131,7 +134,7 @@ function EighAdjoint(; fwd_alg = (;), rrule_alg = (;))
haskey(EIGH_RRULE_SYMBOLS, rrule_kwargs.alg) ||
throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)"))
rrule_type = EIGH_RRULE_SYMBOLS[rrule_kwargs.alg]
if rrule_type <: Union{FullEighPullback, TruncEighPullback}
if rrule_type <: Union{FullPullback, TruncPullback}
rrule_kwargs = (; rrule_kwargs.degeneracy_atol, rrule_kwargs.verbosity)
end

Expand Down Expand Up @@ -191,7 +194,7 @@ Construct an `IterEigh` algorithm struct based on the following keyword argument
fallback_threshold::Float64 = Inf
start_vector = deterministic_start_vector
end
_default_eigh_rrule_alg(::IterEigh) = :trunc
_default_eigh_rrule_alg(::IterEigh) = :TruncPullback

# Compute eigh data block-wise using KrylovKit algorithm
function MatrixAlgebraKit.eigh_trunc!(f, alg::TruncatedAlgorithm{<:IterEigh})
Expand Down Expand Up @@ -300,11 +303,11 @@ end
function ChainRulesCore.rrule(
::typeof(eigh_trunc!),
t::AbstractTensorMap,
alg::EighAdjoint{<:TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, <:FullEighPullback}
alg::EighAdjoint{<:TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, <:FullPullback}
)

D, V = eigh_full!(t; alg.fwd_alg.alg)
(D̃, ), inds = truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc)
(D̃, ), inds = truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc)
truncerror = truncation_error(diagview(D), inds)

gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity)
Expand All @@ -329,7 +332,7 @@ function ChainRulesCore.rrule(
::typeof(eigh_trunc!),
t,
alg::EighAdjoint{F, R}
) where {F, R <: TruncEighPullback}
) where {F, R <: TruncPullback}
D, V, truncerror = eigh_trunc(t, alg)
gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity)

Expand Down
17 changes: 4 additions & 13 deletions src/utility/qr.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
"""
$(TYPEDEF)

QR reverse-rule algorithm which wraps MatrixAlgebraKit's `qr_pullback!`.
"""
@kwdef struct QRPullback
verbosity::Int = 0
end

"""
$(TYPEDEF)

Wrapper for a QR decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`.

## Fields
Expand All @@ -29,7 +20,7 @@ Construct a `QRAdjoint` algorithm struct based on the following keyword argument
- `:DefaultAlgorithm` : MatrixAlgebraKit's [default QR algorithm](@extref MatrixAlgebraKit.DefaultAlgorithm) for a given matrix type.
- `:Householder` : MatrixAlgebraKit's [`Householder`](@extref MatrixAlgebraKit.Householder)
* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following:
- `:qr` : MatrixAlgebraKit's [`qr_pullback!`](@extref MatrixAlgebraKit.qr_pullback!)
- `:FullPullback` : MatrixAlgebraKit's [`qr_pullback!`](@extref MatrixAlgebraKit.qr_pullback!)

!!! note
Manually specifying a `rrule_alg` is considered expert-mode usage, and should only be done when full control over the implementation is desired.
Expand All @@ -41,11 +32,11 @@ struct QRAdjoint{F, R}
end

const QR_FWD_SYMBOLS = IdDict{Symbol, Any}(
# :DefaultAlgorithm => DefaultAlgorithm, # TODO: broken, needs to be fixed
:DefaultAlgorithm => DefaultAlgorithm,
:Householder => Householder,
)
const QR_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}(
:qr => QRPullback
:FullPullback => FullPullback
)

function QRAdjoint(; fwd_alg = (;), rrule_alg = (;))
Expand Down Expand Up @@ -100,7 +91,7 @@ function ChainRulesCore.rrule(
::typeof(left_orth!),
t::AbstractTensorMap,
alg::QRAdjoint{F, R},
) where {F <: MatrixAlgebraKit.Algorithm, R <: QRPullback}
) where {F <: MatrixAlgebraKit.Algorithm, R <: FullPullback}
QR = left_orth(t, alg)
gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity)

Expand Down
64 changes: 9 additions & 55 deletions src/utility/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,6 @@ const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreEx
"""
$(TYPEDEF)

SVD reverse-rule algorithm which wraps MatrixAlgebraKit's `svd_pullback!`.

## Fields

$(TYPEDFIELDS)

## Constructors

FullSVDPullback(; kwargs...)

Construct a `FullSVDPullback` algorithm struct from the following keyword arguments:

* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Broadening amplitude for smoothing divergent term in SVD derivative in case of (pseudo) degenerate singular values.
* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`.
"""
@kwdef struct FullSVDPullback
degeneracy_atol::Real = Defaults.rrule_degeneracy_atol
verbosity::Int = 0
end

"""
$(TYPEDEF)

SVD reverse-rule algorithm which wraps MatrixAlgebraKit's `svd_trunc_pullback!`.

## Fields

$(TYPEDFIELDS)

## Constructors

TruncSVDPullback(; kwargs...)

Construct a `TruncSVDPullback` algorithm struct from the following keyword arguments:

* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Broadening amplitude for smoothing divergent term in SVD derivative in case of (pseudo) degenerate singular values.
* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`.
"""
@kwdef struct TruncSVDPullback
degeneracy_atol::Real = Defaults.rrule_degeneracy_atol
verbosity::Int = 0
end

"""
$(TYPEDEF)

Wrapper for a SVD algorithm `fwd_alg` with a defined reverse rule `rrule_alg`.

## Fields
Expand Down Expand Up @@ -82,8 +36,8 @@ Construct a `SVDAdjoint` algorithm struct based on the following keyword argumen
* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=:$(Defaults.svd_rrule_alg))`:
Reverse-rule algorithm for differentiating the SVD. Can be supplied by an `Algorithm`
instance directly or as a `NamedTuple` where `alg` is one of the following:
- `:full` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum
- `:trunc` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace
- `:FullPullback` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum
- `:TruncPullback` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace
- `:GMRES` : GMRES iterative linear solver, see [`KrylovKit.GMRES`](@extref)
- `:BiCGStab` : BiCGStab iterative linear solver, see [`KrylovKit.BiCGStab`](@extref)
- `:Arnoldi` : Arnoldi Krylov algorithm, see the [`KrylovKit.Arnoldi`](@extref KrylovKit.Arnoldi)
Expand All @@ -108,11 +62,11 @@ const SVD_FWD_SYMBOLS = IdDict{Symbol, Any}(
:GKL => (; tol = 1.0e-14, krylovdim = 25, kwargs...) -> IterSVD(; alg = GKL(; tol, krylovdim), kwargs...),
)
const SVD_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}(
:full => FullSVDPullback, :trunc => TruncSVDPullback,
:FullPullback => FullPullback, :TruncPullback => TruncPullback,
:GMRES => GMRES, :BiCGStab => BiCGStab, :Arnoldi => Arnoldi
)

_default_svd_rrule_alg(::MatrixAlgebraKit.Algorithm) = :full
_default_svd_rrule_alg(::MatrixAlgebraKit.Algorithm) = :FullPullback

function SVDAdjoint(; fwd_alg = (;), rrule_alg = (;))
# parse forward SVD algorithm
Expand Down Expand Up @@ -143,11 +97,11 @@ function SVDAdjoint(; fwd_alg = (;), rrule_alg = (;))
rrule_type = SVD_RRULE_SYMBOLS[rrule_kwargs.alg]

# IterSVD is incompatible with tsvd rrule -> default to Arnoldi
if rrule_type <: FullSVDPullback && fwd_algorithm isa IterSVD
if rrule_type <: FullPullback && fwd_algorithm isa IterSVD
rrule_type = Arnoldi
end

if rrule_type <: Union{FullSVDPullback, TruncSVDPullback}
if rrule_type <: Union{FullPullback, TruncPullback}
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg = nothing, tol = 0.0, krylovdim = 0)) # remove `alg`, `tol` and `krylovdim` keyword arguments
else
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg = nothing, degeneracy_atol = 0.0)) # remove `alg` and `degeneracy_atol` keyword arguments
Expand Down Expand Up @@ -211,7 +165,7 @@ Construct an `IterSVD` algorithm struct based on the following keyword arguments
fallback_threshold::Float64 = Inf
start_vector = deterministic_start_vector
end
_default_svd_rrule_alg(::IterSVD) = :trunc
_default_svd_rrule_alg(::IterSVD) = :TruncPullback

random_start_vector(t::AbstractMatrix) = randn(scalartype(t), size(t, 1))
deterministic_start_vector(t::AbstractMatrix) = ones(scalartype(t), size(t, 1))
Expand Down Expand Up @@ -316,7 +270,7 @@ function ChainRulesCore.rrule(
::typeof(svd_trunc!),
t::AbstractTensorMap,
alg::SVDAdjoint{F, R}
) where {F <: TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, R <: FullSVDPullback}
) where {F <: TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, R <: FullPullback}
# TODO: filter out any decomposition algorithm that doesn't give access to the full spectrum

# requires access to the full decomposition
Expand Down Expand Up @@ -347,7 +301,7 @@ function ChainRulesCore.rrule(
::typeof(svd_trunc!),
t,
alg::SVDAdjoint{F, R},
) where {F, R <: TruncSVDPullback}
) where {F, R <: TruncPullback}
U, S, V⁺, ϵ = svd_trunc(t, alg)
gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity)

Expand Down
4 changes: 2 additions & 2 deletions test/gradients/c4v_ctmrg_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ gradtol = 1.0e-4
ctmrg_verbosity = 1
ctmrg_algs = [[:C4vCTMRG]]
projector_algs = [[:C4vEighProjector, :C4vQRProjector]]
decomposition_rrule_algs = [[:full, :trunc, :qr]]
decomposition_rrule_algs = [[:FullPullback, :TruncPullback]]
gradient_algs = [[nothing, :GeomSum, :ManualIter, :LinSolver, :EigSolver]] # they all use :fixed mode by default (except for nothing)
steps = -0.01:0.005:0.01

Expand All @@ -34,7 +34,7 @@ allowed_rrule_algs = Dict(
)

# be selective on which configurations to test the naive gradient for
naive_gradient_combinations = [(:C4vCTMRG, :C4vEighProjector, :full), (:C4vCTMRG, :C4vQRProjector, :qr)]
naive_gradient_combinations = [(:C4vCTMRG, :C4vEighProjector, :FullPullback), (:C4vCTMRG, :C4vQRProjector, :FullPullback)]
naive_gradient_done = Set()

## Tests
Expand Down
8 changes: 4 additions & 4 deletions test/gradients/ctmrg_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ gradtol = 1.0e-4
ctmrg_verbosity = 0
ctmrg_algs = [[:SequentialCTMRG, :SimultaneousCTMRG], [:SequentialCTMRG, :SimultaneousCTMRG]]
projector_algs = [[:HalfInfiniteProjector, :FullInfiniteProjector], [:HalfInfiniteProjector, :FullInfiniteProjector]]
svd_rrule_algs = [[:full, :trunc, :Arnoldi], [:full, :Arnoldi]]
svd_rrule_algs = [[:FullPullback, :TruncPullback, :Arnoldi], [:FullPullback, :Arnoldi]]
gradient_algs = [
[nothing, :GeomSum, :ManualIter, :LinSolver, :EigSolver],
[:GeomSum, :ManualIter, :LinSolver, :EigSolver],
Expand All @@ -29,9 +29,9 @@ steps = -0.01:0.005:0.01

# don't check naive AD gradients for all algorithm combinations, since it's slow
naive_gradient_combinations = [
(:SimultaneousCTMRG, :HalfInfiniteProjector, :full),
(:SimultaneousCTMRG, :FullInfiniteProjector, :full),
(:SequentialCTMRG, :HalfInfiniteProjector, :full),
(:SimultaneousCTMRG, :HalfInfiniteProjector, :FullPullback),
(:SimultaneousCTMRG, :FullInfiniteProjector, :FullPullback),
(:SequentialCTMRG, :HalfInfiniteProjector, :FullPullback),
]
naive_gradient_done = Set()

Expand Down
6 changes: 3 additions & 3 deletions test/utility/eigh_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ r = 0.5 * (r + r') # make r Hermitian
R = randn(space(r))
R = 0.5 * (R + R')

full_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :full))
trunc_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :trunc))
iter_alg = EighAdjoint(; fwd_alg = (; alg = :Lanczos), rrule_alg = (; alg = :trunc))
full_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :FullPullback))
trunc_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :TruncPullback))
iter_alg = EighAdjoint(; fwd_alg = (; alg = :Lanczos), rrule_alg = (; alg = :TruncPullback))

@testset "Non-truncated eigh" begin
l_full, g_full = withgradient(A -> lossfun(A, full_alg, R), r)
Expand Down
Loading
Loading