diff --git a/Project.toml b/Project.toml index ca18875e4..ea4c90d22 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ OptimKit = "0.4" Printf = "1" Random = "1" Statistics = "1" -TensorKit = "0.16.2" +TensorKit = "0.16.5" TensorOperations = "5" TupleTools = "1.6.0" VectorInterface = "0.4, 0.5" diff --git a/src/Defaults.jl b/src/Defaults.jl index 6cd0ca652..3b4bb963f 100644 --- a/src/Defaults.jl +++ b/src/Defaults.jl @@ -40,8 +40,8 @@ Module containing default algorithm parameter values and arguments. * `svd_rrule_min_krylovdim=$(Defaults.svd_rrule_min_krylovdim)` : Minimal Krylov dimension of the reverse-rule algorithm (if it is a Krylov algorithm). * `svd_rrule_verbosity=$(Defaults.svd_rrule_verbosity)` : SVD gradient output verbosity. * `svd_rrule_alg=:$(Defaults.svd_rrule_alg)` : Reverse-rule algorithm for the SVD gradient. - - `:full` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum - - `:trunc` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace + - `:FullPullback` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum + - `:TruncPullback` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace - `:GMRES` : GMRES iterative linear solver, see [`KrylovKit.GMRES`](@extref) - `:BiCGStab` : BiCGStab iterative linear solver, see [`KrylovKit.BiCGStab`](@extref) - `:Arnoldi` : Arnoldi Krylov algorithm, see the [`KrylovKit.Arnoldi`](@extref) @@ -58,8 +58,8 @@ Module containing default algorithm parameter values and arguments. - `:Lanczos` : Lanczos algorithm for symmetric/Hermitian matrices, see [`KrylovKit.Lanczos`](@extref) - `:BlockLanczos` : Block version of `:Lanczos` for repeated extremal eigenvalues, see [`KrylovKit.BlockLanczos`](@extref) * `eigh_rrule_alg=:$(Defaults.eigh_rrule_alg)` : Reverse-rule algorithm for the `eigh` gradient. - - `:full` : Full pullback algorithm for eigendecompositions, see [`PEPSKit.FullEighPullback`](@ref). - - `:trunc` : Truncated reverse-mode algorithm for eigendecompositions, see [`PEPSKit.TruncEighPullback`](@ref). + - `:FullPullback` : MatrixAlgebraKit's [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!) that requires access to the full spectrum + - `:TruncPullback` : MatrixAlgebraKit's [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) solving a Sylvester equation on the truncated subspace * `eigh_rrule_verbosity=$(Defaults.eigh_rrule_verbosity)` : eigh gradient output verbosity. ## Projectors @@ -126,18 +126,18 @@ const svd_fwd_alg = :DefaultAlgorithm # ∈ {:, const svd_rrule_tol = ctmrg_tol const svd_rrule_min_krylovdim = 48 const svd_rrule_verbosity = -1 -const svd_rrule_alg = :full # ∈ {:full, :trunc, :GMRES, :BiCGStab, :Arnoldi} +const svd_rrule_alg = :FullPullback # ∈ {:FullPullback, :TruncPullback, :GMRES, :BiCGStab, :Arnoldi} const krylovdim_factor = 1.4 # eigh forward & reverse const eigh_fwd_alg = :DefaultAlgorithm # ∈ {:, :Lanczos, :BlockLanczos} -const eigh_rrule_alg = :full # ∈ {:full, :trunc} +const eigh_rrule_alg = :FullPullback # ∈ {:FullPullback, :TruncPullback} const eigh_rrule_verbosity = 0 # QR forward & reverse -const qr_fwd_alg = :Householder +const qr_fwd_alg = :DefaultAlgorithm const qr_fwd_positive = true -const qr_rrule_alg = :qr +const qr_rrule_alg = :FullPullback const qr_rrule_verbosity = 0 # Projectors diff --git a/src/utility/eigh.jl b/src/utility/eigh.jl index eec53ad42..0db3bb7ef 100644 --- a/src/utility/eigh.jl +++ b/src/utility/eigh.jl @@ -1,7 +1,8 @@ """ $(TYPEDEF) -Eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_pullback!`. +Reverse-rule algorithm which wraps MatrixAlgebraKit's full pullback methods, +see [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!), [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!), [`qr_pullback!`](@extref MatrixAlgebraKit.qr_pullback!). ## Fields @@ -9,13 +10,14 @@ $(TYPEDFIELDS) ## Constructors - FullEighPullback(; kwargs...) + FullPullback(; kwargs...) -Construct a `FullEighPullback` algorithm struct from the following keyword arguments: +Construct a `FullPullback` algorithm struct from the following keyword arguments: +* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Absolute tolerance for idendifying degenerate subspaces. * `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. """ -@kwdef struct FullEighPullback +@kwdef struct FullPullback degeneracy_atol::Real = Defaults.rrule_degeneracy_atol verbosity::Int = 0 end @@ -23,21 +25,22 @@ end """ $(TYPEDEF) -Truncated eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_trunc_pullback!`. - +Truncated reverse-rule algorithm which wraps MatrixAlgebraKit's truncated pullback methods, +see [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) and [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!). ## Fields $(TYPEDFIELDS) ## Constructors - TruncEighPullback(; kwargs...) + TruncPullback(; kwargs...) -Construct a `TruncEighPullback` algorithm struct from the following keyword arguments: +Construct a `TruncPullback` algorithm struct from the following keyword arguments: +* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Absolute tolerance for idendifying degenerate subspaces. * `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. """ -@kwdef struct TruncEighPullback +@kwdef struct TruncPullback degeneracy_atol::Real = Defaults.rrule_degeneracy_atol verbosity::Int = 0 end @@ -78,8 +81,8 @@ Construct a `EighAdjoint` algorithm struct based on the following keyword argume Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: - - `:full` : MatrixAlgebraKit's [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!) that requires access to the full spectrum - - `:trunc` : MatrixAlgebraKit's [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) solving a Sylvester equation on the truncated subspace + - `:FullPullback` : MatrixAlgebraKit's [`eigh_pullback!`](@extref MatrixAlgebraKit.eigh_pullback!) that requires access to the full spectrum + - `:TruncPullback` : MatrixAlgebraKit's [`eigh_trunc_pullback!`](@extref MatrixAlgebraKit.eigh_trunc_pullback!) solving a Sylvester equation on the truncated subspace !!! note Manually specifying a `rrule_alg` is considered expert-mode usage, and should only be done when full control over the implementation is desired. @@ -101,10 +104,10 @@ const EIGH_FWD_SYMBOLS = IdDict{Symbol, Any}( :BlockLanczos => (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> IterEigh(; alg = BlockLanczos(; tol, krylovdim), kwargs...), ) const EIGH_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( - :full => FullEighPullback, :trunc => TruncEighPullback, + :FullPullback => FullPullback, :TruncPullback => TruncPullback, ) -_default_eigh_rrule_alg(::MatrixAlgebraKit.Algorithm) = :full +_default_eigh_rrule_alg(::MatrixAlgebraKit.Algorithm) = :FullPullback function EighAdjoint(; fwd_alg = (;), rrule_alg = (;)) # parse forward algorithm @@ -131,7 +134,7 @@ function EighAdjoint(; fwd_alg = (;), rrule_alg = (;)) haskey(EIGH_RRULE_SYMBOLS, rrule_kwargs.alg) || throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)")) rrule_type = EIGH_RRULE_SYMBOLS[rrule_kwargs.alg] - if rrule_type <: Union{FullEighPullback, TruncEighPullback} + if rrule_type <: Union{FullPullback, TruncPullback} rrule_kwargs = (; rrule_kwargs.degeneracy_atol, rrule_kwargs.verbosity) end @@ -191,7 +194,7 @@ Construct an `IterEigh` algorithm struct based on the following keyword argument fallback_threshold::Float64 = Inf start_vector = deterministic_start_vector end -_default_eigh_rrule_alg(::IterEigh) = :trunc +_default_eigh_rrule_alg(::IterEigh) = :TruncPullback # Compute eigh data block-wise using KrylovKit algorithm function MatrixAlgebraKit.eigh_trunc!(f, alg::TruncatedAlgorithm{<:IterEigh}) @@ -300,11 +303,11 @@ end function ChainRulesCore.rrule( ::typeof(eigh_trunc!), t::AbstractTensorMap, - alg::EighAdjoint{<:TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, <:FullEighPullback} + alg::EighAdjoint{<:TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, <:FullPullback} ) D, V = eigh_full!(t; alg.fwd_alg.alg) - (D̃, Ṽ), inds = truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc) + (D̃, Ṽ), inds = truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc) truncerror = truncation_error(diagview(D), inds) gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) @@ -329,7 +332,7 @@ function ChainRulesCore.rrule( ::typeof(eigh_trunc!), t, alg::EighAdjoint{F, R} - ) where {F, R <: TruncEighPullback} + ) where {F, R <: TruncPullback} D, V, truncerror = eigh_trunc(t, alg) gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) diff --git a/src/utility/qr.jl b/src/utility/qr.jl index 9c26dd074..a512821c8 100644 --- a/src/utility/qr.jl +++ b/src/utility/qr.jl @@ -1,15 +1,6 @@ """ $(TYPEDEF) -QR reverse-rule algorithm which wraps MatrixAlgebraKit's `qr_pullback!`. -""" -@kwdef struct QRPullback - verbosity::Int = 0 -end - -""" -$(TYPEDEF) - Wrapper for a QR decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. ## Fields @@ -29,7 +20,7 @@ Construct a `QRAdjoint` algorithm struct based on the following keyword argument - `:DefaultAlgorithm` : MatrixAlgebraKit's [default QR algorithm](@extref MatrixAlgebraKit.DefaultAlgorithm) for a given matrix type. - `:Householder` : MatrixAlgebraKit's [`Householder`](@extref MatrixAlgebraKit.Householder) * `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: - - `:qr` : MatrixAlgebraKit's [`qr_pullback!`](@extref MatrixAlgebraKit.qr_pullback!) + - `:FullPullback` : MatrixAlgebraKit's [`qr_pullback!`](@extref MatrixAlgebraKit.qr_pullback!) !!! note Manually specifying a `rrule_alg` is considered expert-mode usage, and should only be done when full control over the implementation is desired. @@ -41,11 +32,11 @@ struct QRAdjoint{F, R} end const QR_FWD_SYMBOLS = IdDict{Symbol, Any}( - # :DefaultAlgorithm => DefaultAlgorithm, # TODO: broken, needs to be fixed + :DefaultAlgorithm => DefaultAlgorithm, :Householder => Householder, ) const QR_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( - :qr => QRPullback + :FullPullback => FullPullback ) function QRAdjoint(; fwd_alg = (;), rrule_alg = (;)) @@ -100,7 +91,7 @@ function ChainRulesCore.rrule( ::typeof(left_orth!), t::AbstractTensorMap, alg::QRAdjoint{F, R}, - ) where {F <: MatrixAlgebraKit.Algorithm, R <: QRPullback} + ) where {F <: MatrixAlgebraKit.Algorithm, R <: FullPullback} QR = left_orth(t, alg) gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) diff --git a/src/utility/svd.jl b/src/utility/svd.jl index f07b3d796..83093a46f 100644 --- a/src/utility/svd.jl +++ b/src/utility/svd.jl @@ -3,52 +3,6 @@ const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreEx """ $(TYPEDEF) -SVD reverse-rule algorithm which wraps MatrixAlgebraKit's `svd_pullback!`. - -## Fields - -$(TYPEDFIELDS) - -## Constructors - - FullSVDPullback(; kwargs...) - -Construct a `FullSVDPullback` algorithm struct from the following keyword arguments: - -* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Broadening amplitude for smoothing divergent term in SVD derivative in case of (pseudo) degenerate singular values. -* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. -""" -@kwdef struct FullSVDPullback - degeneracy_atol::Real = Defaults.rrule_degeneracy_atol - verbosity::Int = 0 -end - -""" -$(TYPEDEF) - -SVD reverse-rule algorithm which wraps MatrixAlgebraKit's `svd_trunc_pullback!`. - -## Fields - -$(TYPEDFIELDS) - -## Constructors - - TruncSVDPullback(; kwargs...) - -Construct a `TruncSVDPullback` algorithm struct from the following keyword arguments: - -* `degeneracy_atol::Real=$(Defaults.rrule_degeneracy_atol)` : Broadening amplitude for smoothing divergent term in SVD derivative in case of (pseudo) degenerate singular values. -* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. -""" -@kwdef struct TruncSVDPullback - degeneracy_atol::Real = Defaults.rrule_degeneracy_atol - verbosity::Int = 0 -end - -""" -$(TYPEDEF) - Wrapper for a SVD algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. ## Fields @@ -82,8 +36,8 @@ Construct a `SVDAdjoint` algorithm struct based on the following keyword argumen * `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=:$(Defaults.svd_rrule_alg))`: Reverse-rule algorithm for differentiating the SVD. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: - - `:full` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum - - `:trunc` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace + - `:FullPullback` : MatrixAlgebraKit's [`svd_pullback!`](@extref MatrixAlgebraKit.svd_pullback!) that requires access to the full spectrum + - `:TruncPullback` : MatrixAlgebraKit's [`svd_trunc_pullback!`](@extref MatrixAlgebraKit.svd_trunc_pullback!) solving a Sylvester equation on the truncated subspace - `:GMRES` : GMRES iterative linear solver, see [`KrylovKit.GMRES`](@extref) - `:BiCGStab` : BiCGStab iterative linear solver, see [`KrylovKit.BiCGStab`](@extref) - `:Arnoldi` : Arnoldi Krylov algorithm, see the [`KrylovKit.Arnoldi`](@extref KrylovKit.Arnoldi) @@ -108,11 +62,11 @@ const SVD_FWD_SYMBOLS = IdDict{Symbol, Any}( :GKL => (; tol = 1.0e-14, krylovdim = 25, kwargs...) -> IterSVD(; alg = GKL(; tol, krylovdim), kwargs...), ) const SVD_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( - :full => FullSVDPullback, :trunc => TruncSVDPullback, + :FullPullback => FullPullback, :TruncPullback => TruncPullback, :GMRES => GMRES, :BiCGStab => BiCGStab, :Arnoldi => Arnoldi ) -_default_svd_rrule_alg(::MatrixAlgebraKit.Algorithm) = :full +_default_svd_rrule_alg(::MatrixAlgebraKit.Algorithm) = :FullPullback function SVDAdjoint(; fwd_alg = (;), rrule_alg = (;)) # parse forward SVD algorithm @@ -143,11 +97,11 @@ function SVDAdjoint(; fwd_alg = (;), rrule_alg = (;)) rrule_type = SVD_RRULE_SYMBOLS[rrule_kwargs.alg] # IterSVD is incompatible with tsvd rrule -> default to Arnoldi - if rrule_type <: FullSVDPullback && fwd_algorithm isa IterSVD + if rrule_type <: FullPullback && fwd_algorithm isa IterSVD rrule_type = Arnoldi end - if rrule_type <: Union{FullSVDPullback, TruncSVDPullback} + if rrule_type <: Union{FullPullback, TruncPullback} rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg = nothing, tol = 0.0, krylovdim = 0)) # remove `alg`, `tol` and `krylovdim` keyword arguments else rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg = nothing, degeneracy_atol = 0.0)) # remove `alg` and `degeneracy_atol` keyword arguments @@ -211,7 +165,7 @@ Construct an `IterSVD` algorithm struct based on the following keyword arguments fallback_threshold::Float64 = Inf start_vector = deterministic_start_vector end -_default_svd_rrule_alg(::IterSVD) = :trunc +_default_svd_rrule_alg(::IterSVD) = :TruncPullback random_start_vector(t::AbstractMatrix) = randn(scalartype(t), size(t, 1)) deterministic_start_vector(t::AbstractMatrix) = ones(scalartype(t), size(t, 1)) @@ -316,7 +270,7 @@ function ChainRulesCore.rrule( ::typeof(svd_trunc!), t::AbstractTensorMap, alg::SVDAdjoint{F, R} - ) where {F <: TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, R <: FullSVDPullback} + ) where {F <: TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, R <: FullPullback} # TODO: filter out any decomposition algorithm that doesn't give access to the full spectrum # requires access to the full decomposition @@ -347,7 +301,7 @@ function ChainRulesCore.rrule( ::typeof(svd_trunc!), t, alg::SVDAdjoint{F, R}, - ) where {F, R <: TruncSVDPullback} + ) where {F, R <: TruncPullback} U, S, V⁺, ϵ = svd_trunc(t, alg) gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) diff --git a/test/gradients/c4v_ctmrg_gradients.jl b/test/gradients/c4v_ctmrg_gradients.jl index 06ca98513..5b5230968 100644 --- a/test/gradients/c4v_ctmrg_gradients.jl +++ b/test/gradients/c4v_ctmrg_gradients.jl @@ -23,7 +23,7 @@ gradtol = 1.0e-4 ctmrg_verbosity = 1 ctmrg_algs = [[:C4vCTMRG]] projector_algs = [[:C4vEighProjector, :C4vQRProjector]] -decomposition_rrule_algs = [[:full, :trunc, :qr]] +decomposition_rrule_algs = [[:FullPullback, :TruncPullback]] gradient_algs = [[nothing, :GeomSum, :ManualIter, :LinSolver, :EigSolver]] # they all use :fixed mode by default (except for nothing) steps = -0.01:0.005:0.01 @@ -34,7 +34,7 @@ allowed_rrule_algs = Dict( ) # be selective on which configurations to test the naive gradient for -naive_gradient_combinations = [(:C4vCTMRG, :C4vEighProjector, :full), (:C4vCTMRG, :C4vQRProjector, :qr)] +naive_gradient_combinations = [(:C4vCTMRG, :C4vEighProjector, :FullPullback), (:C4vCTMRG, :C4vQRProjector, :FullPullback)] naive_gradient_done = Set() ## Tests diff --git a/test/gradients/ctmrg_gradients.jl b/test/gradients/ctmrg_gradients.jl index 0c277e144..c3e18ff06 100644 --- a/test/gradients/ctmrg_gradients.jl +++ b/test/gradients/ctmrg_gradients.jl @@ -20,7 +20,7 @@ gradtol = 1.0e-4 ctmrg_verbosity = 0 ctmrg_algs = [[:SequentialCTMRG, :SimultaneousCTMRG], [:SequentialCTMRG, :SimultaneousCTMRG]] projector_algs = [[:HalfInfiniteProjector, :FullInfiniteProjector], [:HalfInfiniteProjector, :FullInfiniteProjector]] -svd_rrule_algs = [[:full, :trunc, :Arnoldi], [:full, :Arnoldi]] +svd_rrule_algs = [[:FullPullback, :TruncPullback, :Arnoldi], [:FullPullback, :Arnoldi]] gradient_algs = [ [nothing, :GeomSum, :ManualIter, :LinSolver, :EigSolver], [:GeomSum, :ManualIter, :LinSolver, :EigSolver], @@ -29,9 +29,9 @@ steps = -0.01:0.005:0.01 # don't check naive AD gradients for all algorithm combinations, since it's slow naive_gradient_combinations = [ - (:SimultaneousCTMRG, :HalfInfiniteProjector, :full), - (:SimultaneousCTMRG, :FullInfiniteProjector, :full), - (:SequentialCTMRG, :HalfInfiniteProjector, :full), + (:SimultaneousCTMRG, :HalfInfiniteProjector, :FullPullback), + (:SimultaneousCTMRG, :FullInfiniteProjector, :FullPullback), + (:SequentialCTMRG, :HalfInfiniteProjector, :FullPullback), ] naive_gradient_done = Set() diff --git a/test/utility/eigh_wrapper.jl b/test/utility/eigh_wrapper.jl index d7ff4ed5f..4c437d4fc 100644 --- a/test/utility/eigh_wrapper.jl +++ b/test/utility/eigh_wrapper.jl @@ -26,9 +26,9 @@ r = 0.5 * (r + r') # make r Hermitian R = randn(space(r)) R = 0.5 * (R + R') -full_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :full)) -trunc_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :trunc)) -iter_alg = EighAdjoint(; fwd_alg = (; alg = :Lanczos), rrule_alg = (; alg = :trunc)) +full_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :FullPullback)) +trunc_alg = EighAdjoint(; fwd_alg = (; alg = :QRIteration), rrule_alg = (; alg = :TruncPullback)) +iter_alg = EighAdjoint(; fwd_alg = (; alg = :Lanczos), rrule_alg = (; alg = :TruncPullback)) @testset "Non-truncated eigh" begin l_full, g_full = withgradient(A -> lossfun(A, full_alg, R), r) diff --git a/test/utility/svd_wrapper.jl b/test/utility/svd_wrapper.jl index b1667810b..54219d2f5 100644 --- a/test/utility/svd_wrapper.jl +++ b/test/utility/svd_wrapper.jl @@ -24,8 +24,8 @@ Random.seed!(12345678) r = randn(dtype, ℂ^m, ℂ^n) R = randn(space(r)) -full_alg = SVDAdjoint(; rrule_alg = (; alg = :full, degeneracy_atol = 1.0e-13)) -trunc_alg = SVDAdjoint(; rrule_alg = (; alg = :trunc, degeneracy_atol = 1.0e-13)) +full_alg = SVDAdjoint(; rrule_alg = (; alg = :FullPullback, degeneracy_atol = 1.0e-13)) +trunc_alg = SVDAdjoint(; rrule_alg = (; alg = :TruncPullback, degeneracy_atol = 1.0e-13)) iter_alg = SVDAdjoint(; fwd_alg = (; alg = :GKL)) @testset "Non-truncated SVD" begin