diff --git a/Project.toml b/Project.toml index af43fc925..668ec1cf2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.17.0" authors = ["Jutho Haegeman, Lukas Devos"] [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -19,7 +20,6 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" [weakdeps] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -28,7 +28,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [extensions] -TensorKitAdaptExt = "Adapt" TensorKitAMDGPUExt = "AMDGPU" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" @@ -55,7 +54,7 @@ Random = "1" ScopedValues = "1.3.0" Strided = "2" TensorKitSectors = "0.3.7" -TensorOperations = "5.1" +TensorOperations = "5.5" TupleTools = "1.5" VectorInterface = "0.4.8, 0.5" cuTENSOR = "6" diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index b19537e3f..1ef3da0fb 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -2,6 +2,7 @@ ```@meta CurrentModule = TensorKit +CollapsedDocStrings = true ``` ## Type hierarchy @@ -184,12 +185,6 @@ repartition! twist! ``` -```@docs -TensorKit.add_permute! -TensorKit.add_braid! -TensorKit.add_transpose! -``` - ### Tensor map composition, traces, contractions and tensor products ```@docs diff --git a/ext/TensorKitAdaptExt.jl b/ext/TensorKitAdaptExt.jl deleted file mode 100644 index 4d2693b7b..000000000 --- a/ext/TensorKitAdaptExt.jl +++ /dev/null @@ -1,26 +0,0 @@ -module TensorKitAdaptExt - -using TensorKit -using TensorKit: AdjointTensorMap -using Adapt - -function Adapt.adapt_structure(to, x::TensorMap) - data′ = adapt(to, x.data) - return TensorMap{eltype(data′)}(data′, space(x)) -end -function Adapt.adapt_structure(to, x::AdjointTensorMap) - return adjoint(adapt(to, parent(x))) -end -function Adapt.adapt_structure(to, x::DiagonalTensorMap) - data′ = adapt(to, x.data) - return DiagonalTensorMap(data′, x.domain) -end -function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A} - A′ = TensorKit.similarstoragetype(A, T) - return BraidingTensor{T, S, A′}(space(x), x.adjoint) -end -function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A} - return BraidingTensor{T′, S, TA}(space(x), x.adjoint) -end - -end diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index 1a5c28f7c..aa7320f32 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -7,6 +7,8 @@ import CUDA.cuRAND: rand as curand, rand! as curand!, randn as curandn, randn! a using Strided: StridedViews using CUDA.CUDACore.KernelAbstractions: @kernel, @index, get_backend +using Adapt: Adapt + using TensorKit using TensorKit.Factorizations using TensorKit.Strided diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index 02ca2d5d6..016749fce 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -155,7 +155,3 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth) return tf end end - -function TensorKit._add_transform_multi!(tdst::CuTensorMap, tsrc, p, (U, structs_dst, structs_src)::Tuple{<:Array, TD, TS}, buffers, alpha, beta, backend...) where {TD, TS} - return TensorKit._add_transform_multi!(tdst, tsrc, p, (CUDA.CUDACore.Adapt.adapt(CuArray, U), structs_dst, structs_src), buffers, alpha, beta, backend...) -end diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 4b2111c50..a87e6c97a 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -51,19 +51,19 @@ end function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy} # returning a CuSectorVector wrecks things in truncate_{co}domain # because of scalar indexing - return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) + return Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) end for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace)) @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat) # returning a CuSectorVector wrecks things in truncate_{co}domain # because of scalar indexing - return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) + return Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) end end function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) - return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values)) + return SectorDict(c => Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values)) end diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index c3bc3e26c..1fccdd9e6 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -1,11 +1,11 @@ for transform in (:permute, :transpose) - add_transform! = Symbol(:add_, transform, :!) - add_transform_pullback = Symbol(add_transform!, :_pullback) + transform! = Symbol(transform, :!) + transform_pullback = Symbol(transform!, :_pullback) @eval @is_primitive( DefaultCtx, ReverseMode, Tuple{ - typeof(TK.$add_transform!), + typeof(TK.$transform!), AbstractTensorMap, AbstractTensorMap, Index2Tuple, Number, Number, Vararg{Any}, @@ -13,7 +13,7 @@ for transform in (:permute, :transpose) ) @eval function Mooncake.rrule!!( - ::CoDual{typeof(TK.$add_transform!)}, + ::CoDual{typeof(TK.$transform!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, @@ -30,17 +30,17 @@ for transform in (:permute, :transpose) # if we need to compute Δa, it is faster to allocate an intermediate permuted A # and store that instead of repeating the permutation in the pullback each time. - # effectively, we replace `add_permute` by `add ∘ permute`. + # effectively, we replace `permute!/transpose!` by `add ∘ permute/transpose`. Ap = if _needs_tangent(α) Ap = $transform(A, p) add!(C, Ap, α, β) Ap else - TK.$add_transform!(C, A, p, α, β, ba...) + TK.$transform!(C, A, p, α, β, ba...) nothing end - function $add_transform_pullback(::NoRData) + function $transform_pullback(::NoRData) copy!(C, C_cache) # ΔA @@ -50,10 +50,10 @@ for transform in (:permute, :transpose) TC = VectorInterface.promote_scale(ΔC, α) if scalartype(ΔA) <: Real && !(TC <: Real) ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) - TK.$add_transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...) + TK.$transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...) add!(ΔA, real(ΔAc)) else - TK.$add_transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) + TK.$transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) end ΔAr = NoRData() @@ -64,7 +64,7 @@ for transform in (:permute, :transpose) return NoRData(), ΔCr, ΔAr, NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... end - return C_ΔC, $add_transform_pullback + return C_ΔC, $transform_pullback end end @@ -72,7 +72,7 @@ end DefaultCtx, ReverseMode, Tuple{ - typeof(TK.add_braid!), + typeof(TK.braid!), AbstractTensorMap, AbstractTensorMap, Index2Tuple, IndexTuple, Number, Number, Vararg{Any}, @@ -80,7 +80,7 @@ end ) function Mooncake.rrule!!( - ::CoDual{typeof(TK.add_braid!)}, + ::CoDual{typeof(TK.braid!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, levels_Δlevels::CoDual{<:IndexTuple}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, @@ -98,17 +98,17 @@ function Mooncake.rrule!!( # if we need to compute Δa, it is faster to allocate an intermediate braided A # and store that instead of repeating the permutation in the pullback each time. - # effectively, we replace `add_permute` by `add ∘ permute`. + # effectively, we replace `braid!` by `add ∘ braid`. Ap = if _needs_tangent(α) Ap = braid(A, p, levels) add!(C, Ap, α, β) Ap else - TK.add_braid!(C, A, p, levels, α, β, ba...) + TK.braid!(C, A, p, levels, α, β, ba...) nothing end - function add_braid!_pullback(::NoRData) + function braid!_pullback(::NoRData) copy!(C, C_cache) # ΔA @@ -118,10 +118,10 @@ function Mooncake.rrule!!( TC = VectorInterface.promote_scale(ΔC, α) if scalartype(ΔA) <: Real && !(TC <: Real) ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) - TK.add_braid!(ΔAc, ΔC, pΔA, ilevels, conj(α), Zero(), ba...) + TK.braid!(ΔAc, ΔC, pΔA, ilevels, conj(α), Zero(), ba...) add!(ΔA, real(ΔAc)) else - TK.add_braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) + TK.braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) end ΔAr = NoRData() @@ -132,7 +132,7 @@ function Mooncake.rrule!!( return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... end - return C_ΔC, add_braid!_pullback + return C_ΔC, braid!_pullback end # both are needed for correctly capturing every dispatch diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index 3c75fe2da..abbef5004 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -60,7 +60,7 @@ # if length(q[1]) == 0 # ip = invperm(linearize(p)) # pΔA = _repartition(ip, A) -# TK.add_transpose!(ΔA, ΔC, pΔA, conj(α), One(), backend, allocator) +# TK.transpose!(ΔA, ΔC, pΔA, conj(α), One(), backend, allocator) # return NoRData() # end # # if length(q[1]) == 1 diff --git a/src/TensorKit.jl b/src/TensorKit.jl index d8361ac79..6a2828588 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -91,8 +91,7 @@ export left_orth, right_orth, left_null, right_null, isisometric, isunitary, project_isometric, project_isometric!, isposdef, isposdef!, sylvester, rank, cond -export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, - repartition! +export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! export catdomain, catcodomain, absorb, absorb! # tensor operations @@ -150,6 +149,8 @@ import Base.Meta using Random: Random, rand!, randn! +using Adapt: Adapt + # Auxiliary files #----------------- include("auxiliary/auxiliary.jl") diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index cde772982..758bb708a 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -32,7 +32,7 @@ function planaradd!( α::Number, β::Number, backend, allocator ) - return add_transpose!(C, A, p, α, β, backend) + return transpose!(C, A, p, α, β, backend) end # insert default backend @@ -173,7 +173,7 @@ function planarcontract!( A′ = TO.tensoralloc_add( scalartype(A), A, (oindA, cindA), false, Val(true), allocator ) - add_transpose!(A′, A, (oindA, cindA), One(), Zero(), backend) + transpose!(A′, A, (oindA, cindA), One(), Zero(), backend) end if cindB == codB && oindB == domB @@ -182,7 +182,7 @@ function planarcontract!( B′ = TensorOperations.tensoralloc_add( scalartype(B), B, (cindB, oindB), false, Val(true), allocator ) - add_transpose!(B′, B, (cindB, oindB), One(), Zero(), backend) + transpose!(B′, B, (cindB, oindB), One(), Zero(), backend) end mul!(C, A′, B′, α, β) (oindA == codA && cindA == domA) || TO.tensorfree!(A′, allocator) diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index ca484e77b..820e87375 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -50,6 +50,8 @@ Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tu return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...)) end +Adapt.adapt_structure(to, x::AdjointTensorMap) = adjoint(Adapt.adapt(to, parent(x))) + # Show #------ function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index d28b2e1df..a1b7dbd02 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -53,6 +53,14 @@ function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T} return BraidingTensor{T}(V[2], V[1], adjoint) end +function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A} + A′ = TensorKit.similarstoragetype(A, T) + return BraidingTensor{T, S, A′}(space(x), x.adjoint) +end +function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A} + return BraidingTensor{T′, S, TA}(space(x), x.adjoint) +end + function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A} return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint) end @@ -221,11 +229,11 @@ function planarcontract!( I = sectortype(C) BraidingStyle(I) isa Bosonic && - return add_permute!(C, B, (reverse(cindB), oindB), α, β, backend) + return permute!(C, B, (reverse(cindB), oindB), α, β, backend, allocator) # Non-bosonic case: factor into a cyclic transpose (no crossings) + a single Artin braid # that swaps the two contracted legs, producing the R-symbol that A encodes. Naively - # using a single `add_braid!` is wrong: it would resolve cyclic moves as crossings and + # using a single `braid!` is wrong: it would resolve cyclic moves as crossings and # pick up spurious R-symbol factors. B_in_layout = (cindB == codB && oindB == domB) if B_in_layout @@ -234,7 +242,7 @@ function planarcontract!( B′ = TO.tensoralloc_add( scalartype(B), B, (cindB, oindB), false, Val(true), allocator ) - add_transpose!(B′, B, (cindB, oindB), One(), Zero(), backend) + transpose!(B′, B, (cindB, oindB), One(), Zero(), backend, allocator) end levelsA = A.adjoint ? (1, 2, 2, 1) : (2, 1, 1, 2) @@ -244,9 +252,9 @@ function planarcontract!( ntuple(Returns(3), N - 2)..., ) - add_braid!( + braid!( C, B′, ((2, 1), ntuple(i -> i + 2, N - 2)), - levels, α, β, backend, + levels, α, β, backend, allocator ) B_in_layout || TO.tensorfree!(B′, allocator) @@ -274,11 +282,11 @@ function planarcontract!( I = sectortype(C) BraidingStyle(I) isa Bosonic && - return add_permute!(C, A, (oindA, reverse(cindA)), α, β, backend) + return permute!(C, A, (oindA, reverse(cindA)), α, β, backend, allocator) # Non-bosonic case: cyclic transpose A → (oindA, cindA) (no crossings), then a single # Artin braid swaps A′'s last two indices, producing the R-symbol that B encodes. Naively - # using a single `add_braid!` is wrong: it would resolve cyclic moves as crossings and + # using a single `braid!` is wrong: it would resolve cyclic moves as crossings and # pick up spurious R-symbol factors. A_in_layout = (oindA == codA && cindA == domA) @@ -288,7 +296,7 @@ function planarcontract!( A′ = TO.tensoralloc_add( scalartype(A), A, (oindA, cindA), false, Val(true), allocator ) - add_transpose!(A′, A, (oindA, cindA), One(), Zero(), backend) + transpose!(A′, A, (oindA, cindA), One(), Zero(), backend, allocator) end levelsB = B.adjoint ? (1, 2, 2, 1) : (2, 1, 1, 2) @@ -299,9 +307,9 @@ function planarcontract!( levelsB[cindB[1]], levelsB[cindB[2]], ) - add_braid!( + braid!( C, A′, (ntuple(identity, M), (N, N - 1)), - levels, α, β, backend, + levels, α, β, backend, allocator ) A_in_layout || TO.tensorfree!(A′, allocator) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index b2ac4134b..5fc6682e3 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -133,6 +133,11 @@ function Base.convert(::Type{DiagonalTensorMap}, d::Dict{Symbol, Any}) return convert(DiagonalTensorMap, convert(TensorMap, d)) end +function Adapt.adapt_structure(to, x::DiagonalTensorMap) + data′ = Adapt.adapt(to, x.data) + return DiagonalTensorMap(data′, x.domain) +end + # Complex, real and imaginary parts #----------------------------------- for f in (:real, :imag, :complex) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 3108abb17..655f0ca7f 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -1,28 +1,10 @@ -# Index manipulations -#--------------------- - -# find the scalartype after applying operations: take into account fusion and/or braiding -# might need to become Float or Complex to capture complex recoupling coefficients but don't alter precision -for (operation, manipulation) in ( - :flip => :sector, :twist => :braiding, - :transpose => :fusion, :permute => :sector, :braid => :sector, - ) - promote_op = Symbol(:promote_, operation) - manipulation_scalartype = Symbol(manipulation, :scalartype) - - @eval begin - $promote_op(t::AbstractTensorMap) = $promote_op(typeof(t)) - $promote_op(::Type{T}) where {T <: AbstractTensorMap} = - $promote_op(scalartype(T), sectortype(T)) - $promote_op(::Type{T}, ::Type{I}) where {T <: Number, I <: Sector} = - sectorscalartype(I) <: Integer ? T : - sectorscalartype(I) <: Real ? float(T) : complex(T) - # TODO: currently the manipulations all use sectorscalartype, change to: - # $manipulation_scalartype(I) <: Integer ? T : - # $manipulation_scalartype(I) <: Real ? float(T) : complex(T) - end -end +# ============= +# Reweighting +# ============= +# ------ +# flip +# ------ """ flip(t::AbstractTensorMap, I) -> t′::AbstractTensorMap @@ -46,33 +28,221 @@ function flip(t::AbstractTensorMap, I; inv::Bool = false) return t′ end +# --------- +# twist(!) +# --------- +function has_shared_twist(t, inds) + I = sectortype(t) + if BraidingStyle(I) == NoBraiding() + for i in inds + cs = sectors(space(t, i)) + all(isunit, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs")) + end + return true + elseif BraidingStyle(I) == Bosonic() + return true + else + for i in inds + cs = sectors(space(t, i)) + all(isone ∘ twist, cs) || return false + end + return true + end +end + +""" + twist!(t::AbstractTensorMap, i::Int; inv::Bool = false) -> t + twist!(t::AbstractTensorMap, inds; inv::Bool = false) -> t + +Apply a twist to the `i`th index of `t`, or all indices in `inds`, storing the result in `t`. +If `inv=true`, use the inverse twist. + +See [`twist`](@ref) for creating a new tensor. +""" +function twist!(t::AbstractTensorMap, inds; inv::Bool = false) + if !all(in(allind(t)), inds) + msg = "Can't twist indices $inds of a tensor with only $(numind(t)) indices." + throw(ArgumentError(msg)) + end + (scalartype(t) <: Real && !(sectorscalartype(sectortype(t)) <: Real)) && + throw(ArgumentError("Can't in-place twist a real tensor with complex sector type")) + has_shared_twist(t, inds) && return t + + N₁ = numout(t) + for (f₁, f₂) in fusiontrees(t) + θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), inds) + inv && (θ = θ') + scale!(t[f₁, f₂], θ) + end + return t +end + +""" + twist(tsrc::AbstractTensorMap, i::Int; inv::Bool = false, copy::Bool = false) -> tdst + twist(tsrc::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) -> tdst + +Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor. +If `inv = true`, use the inverse twist. +If `copy = false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. + +See [`twist!`](@ref) for storing the result in place. +""" +function twist(t::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) + if has_shared_twist(t, inds) + return copy ? Base.copy(t) : t + end + tdst = similar(t, promote_twist(t)) + copy!(tdst, t) + return twist!(tdst, inds; inv) +end + +# ========================= +# Space insertion/removal +# ========================= + +# Methods which change the number of indices, implement using `Val(i)` for type inference +""" + insertleftunit( + tsrc::AbstractTensorMap, i = numind(t) + 1; + conj = false, dual = false, copy = false + ) -> tdst + +Insert a trivial vector space, isomorphic to the underlying field, at position `i`, +which can be specified as an `Int` or as `Val(i)` for improved type stability. +More specifically, adds a left monoidal unit or its dual. +Insert a trivial vector space, isomorphic to the underlying field, before position `i`, +which should satisfy `1 ≤ i ≤ numind(t) + 1` +and can be specified as an `Int` or as `Val(i)` for improved type stability, +More specifically, add a left monoidal unit (or its dual) of the space associated with index `i`. +The new index appears at position `i` in the new tensor, +namely in its codomain for `1 ≤ i ≤ numout(t)` and in its domain otherwise. +If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. + +See also [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) where {i}), +[`removeunit`](@ref removeunit(::AbstractTensorMap, ::Val{i}) where {i}). +""" +function insertleftunit( + t::AbstractTensorMap, ::Val{i} = Val(numind(t) + 1); + copy::Bool = false, conj::Bool = false, dual::Bool = false + ) where {i} + W = insertleftunit(space(t), Val(i); conj, dual) + if t isa TensorMap + return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) + else + tdst = similar(t, W) + for (c, b) in blocks(t) + copy!(block(tdst, c), b) + end + return tdst + end +end + +""" + insertrightunit( + tsrc::AbstractTensorMap, i = numind(t); + conj = false, dual = false, copy = false + ) -> tdst + +Insert a trivial vector space, isomorphic to the underlying field, after position `i`, +which should satisfy `0 ≤ i ≤ numind(t)` +and can be specified as an `Int` or as `Val(i)` for improved type stability, +More specifically, add a right monoidal unit (or its dual) of the space associated with index `i`. +The new index appears at position `i+1` in the new tensor, +namely in its codomain for `0 ≤ i ≤ numout(t)` and in its domain otherwise. + +If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. + +See also [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) where {i}), +[`removeunit`](@ref removeunit(::AbstractTensorMap, ::Val{i}) where {i}). +""" +function insertrightunit( + t::AbstractTensorMap, ::Val{i} = Val(numind(t)); + copy::Bool = false, conj::Bool = false, dual::Bool = false + ) where {i} + W = insertrightunit(space(t), Val(i); conj, dual) + if t isa TensorMap + return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) + else + tdst = similar(t, W) + for (c, b) in blocks(t) + copy!(block(tdst, c), b) + end + return tdst + end +end + +""" + removeunit(tsrc::AbstractTensorMap, i; copy = false) -> tdst + +This removes a trivial tensor product factor at position `1 ≤ i ≤ N`, where `i` +can be specified as an `Int` or as `Val(i)` for improved type stability. +For this to work, that factor has to be isomorphic to the field of scalars. + +If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. + +This operation undoes the work of [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) where {i}) +and [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) where {i}). +""" +function removeunit(t::AbstractTensorMap, ::Val{i}; copy::Bool = false) where {i} + W = removeunit(space(t), Val(i)) + if t isa TensorMap + return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) + else + tdst = similar(t, W) + for (c, b) in blocks(t) + copy!(block(tdst, c), b) + end + return tdst + end +end + +# TODO: fusion/splitting of indices + +# ============================ +# Index rearrangements +# ============================ + +# -------------- +# permute(!) +# -------------- """ - permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple) - -> tdst + permute!(tdst, tsrc, (p₁, p₂)::Index2Tuple, α = 1, β = 0, [backend], [allocator]) -> tdst -Write into `tdst` the result of permuting the indices of `tsrc`. +Compute `tdst = β * tdst + α * permute(tsrc, (p₁, p₂))`, writing the result into `tdst`. The codomain and domain of `tdst` correspond to the indices in `p₁` and `p₂` of `tsrc` respectively. - -See [`permute`](@ref) for creating a new tensor and [`add_permute!`](@ref) for a more general version. +Optionally specify a `backend` and `allocator` for the underlying array operation. + +See also [`permute`](@ref) for creating a new tensor. """ @propagate_inbounds function Base.permute!( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple + tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, + α::Number = One(), β::Number = Zero(), + backend::AbstractBackend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() ) - return add_permute!(tdst, tsrc, p, One(), Zero()) + @boundscheck spacecheck_transform(permute, tdst, tsrc, p) + levels = ntuple(identity, numind(tsrc)) + return @inbounds braid!(tdst, tsrc, p, levels, α, β, backend, allocator) end """ - permute(tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple; copy::Bool = false) -> tdst::TensorMap + permute( + tsrc, (p₁, p₂)::Index2Tuple; copy = false, + backend = DefaultBackend(), allocator = DefaultAllocator() + ) -> tdst::TensorMap Return tensor `tdst` obtained by permuting the indices of `tsrc`. The codomain and domain of `tdst` correspond to the indices in `p₁` and `p₂` of `tsrc` respectively. If `copy = false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. +Optionally specify a `backend` and `allocator` for the underlying array operation. -To permute into an existing destination, see [permute!](@ref) and [`add_permute!`](@ref) +See also [`permute!`](@ref) for writing into an existing destination. """ -function permute(t::AbstractTensorMap, p::Index2Tuple; copy::Bool = false) +function permute( + t::AbstractTensorMap, p::Index2Tuple; + copy::Bool = false, backend::AbstractBackend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() + ) # share data if possible if !copy if p == (codomainind(t), domainind(t)) @@ -84,14 +254,15 @@ function permute(t::AbstractTensorMap, p::Index2Tuple; copy::Bool = false) # general case tdst = similar(t, promote_permute(t), permute(space(t), p)) - return @inbounds permute!(tdst, t, p) + levels = ntuple(identity, numind(t)) + return @inbounds braid!(tdst, t, p, levels, One(), Zero(), backend, allocator) end -function permute(t::AdjointTensorMap, (p₁, p₂)::Index2Tuple; copy::Bool = false) +function permute(t::AdjointTensorMap, (p₁, p₂)::Index2Tuple; kwargs...) p₁′ = adjointtensorindices(t, p₂) p₂′ = adjointtensorindices(t, p₁) - return adjoint(permute(adjoint(t), (p₁′, p₂′); copy)) + return adjoint(permute(adjoint(t), (p₁′, p₂′); kwargs...)) end -permute(t::AbstractTensorMap, p::IndexTuple; copy::Bool = false) = permute(t, (p, ()); copy) +permute(t::AbstractTensorMap, p::IndexTuple; kwargs...) = permute(t, (p, ()); kwargs...) function has_shared_permute(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) return (p₁ === codomainind(t) && p₂ === domainind(t)) @@ -100,8 +271,8 @@ function has_shared_permute(t::TensorMap, (p₁, p₂)::Index2Tuple) if p₁ === codomainind(t) && p₂ === domainind(t) return true elseif sectortype(t) === Trivial - stridet = i -> stride(t[], i) - sizet = i -> size(t[], i) + stridet = Base.Fix1(stride, t[]) + sizet = Base.Fix1(size, t[]) canfuse1, d1, s1 = TO._canfuse(sizet.(p₁), stridet.(p₁)) canfuse2, d2, s2 = TO._canfuse(sizet.(p₂), stridet.(p₂)) return canfuse1 && canfuse2 && s1 == 1 && (d2 == 1 || s2 == d1) @@ -115,29 +286,37 @@ function has_shared_permute(t::AdjointTensorMap, (p₁, p₂)::Index2Tuple) return has_shared_permute(t', (p₁′, p₂′)) end -# Braid +# ------------- +# braid(!) +# ------------- """ - braid!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, - (p₁, p₂)::Index2Tuple, levels::Tuple) - -> tdst + braid!(tdst, tsrc, (p₁, p₂)::Index2Tuple, levels::IndexTuple, α = 1, β = 0, [backend], [allocator]) -> tdst -Write into `tdst` the result of braiding the indices of `tsrc`. +Compute `tdst = β * tdst + α * braid(tsrc, (p₁, p₂), levels)`, writing the result into `tdst`. The codomain and domain of `tdst` correspond to the indices in `p₁` and `p₂` of `tsrc` respectively. Here, `levels` is a tuple of length `numind(tsrc)` that assigns a level or height to the indices of `tsrc`, which determines whether they will braid over or under any other index with which they have to change places. +Optionally specify a `backend` and `allocator` for the underlying array operation. -See [`braid`](@ref) for creating a new tensor and [`add_braid!`](@ref) for a more general version. +See also [`braid`](@ref) for creating a new tensor. """ @propagate_inbounds function braid!( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple + tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple, + α::Number = One(), β::Number = Zero(), + backend::AbstractBackend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() ) - return add_braid!(tdst, tsrc, p, levels, One(), Zero()) + @boundscheck spacecheck_transform(braid, tdst, tsrc, p, levels) + levels1 = TupleTools.getindices(levels, codomainind(tsrc)) + levels2 = TupleTools.getindices(levels, domainind(tsrc)) + transformer = treebraider(tdst, tsrc, p, (levels1, levels2)) + return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend, allocator) end """ - braid(tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, levels::IndexTuple; - copy::Bool = false) - -> tdst::TensorMap + braid( + tsrc, (p₁, p₂)::Index2Tuple, levels::IndexTuple; copy = false, + backend = DefaultBackend(), allocator = DefaultAllocator() + ) -> tdst::TensorMap Return tensor `tdst` obtained by braiding the indices of `tsrc`. The codomain and domain of `tdst` correspond to the indices in `p₁` and `p₂` of `tsrc` respectively. @@ -145,277 +324,178 @@ Here, `levels` is a tuple of length `numind(tsrc)` that assigns a level or heigh which determines whether they will braid over or under any other index with which they have to change places. If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. +Optionally specify a `backend` and `allocator` for the underlying array operation. -To braid into an existing destination, see [braid!](@ref) and [`add_braid!`](@ref) +See also [`braid!`](@ref) for writing into an existing destination. """ function braid( - t::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple; copy::Bool = false + t::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple; + copy::Bool = false, backend::AbstractBackend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() ) - length(levels) == numind(t) || throw(ArgumentError("invalid levels")) + length(levels) == numind(t) || throw(ArgumentError(lazy"length of levels should be $(numind(t)), got $(length(levels))")) - BraidingStyle(sectortype(t)) isa SymmetricBraiding && return permute(t, p; copy) (!copy && p == (codomainind(t), domainind(t))) && return t # general case tdst = similar(t, promote_braid(t), permute(space(t), p)) - return @inbounds braid!(tdst, t, p, levels) + return @inbounds braid!(tdst, t, p, levels, One(), Zero(), backend, allocator) +end +function braid( + t::AdjointTensorMap, (p₁, p₂)::Index2Tuple, levels::IndexTuple; + kwargs... + ) + p₁′ = adjointtensorindices(t, p₂) + p₂′ = adjointtensorindices(t, p₁) + perm = adjointtensorindices(adjoint(t), ntuple(identity, numind(t))) + levels′ = TupleTools.getindices(levels, perm) + return adjoint(braid(adjoint(t), (p₁′, p₂′), levels′; kwargs...)) end -# TODO: braid for `AdjointTensorMap`; think about how to map the `levels` argument. -# Transpose +# ---------------- +# transpose(!) +# ---------------- _transpose_indices(t::AbstractTensorMap) = (reverse(domainind(t)), reverse(codomainind(t))) """ - transpose!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, - (p₁, p₂)::Index2Tuple) - -> tdst + transpose!(tdst, tsrc, (p₁, p₂)::Index2Tuple, α = 1, β = 0, [backend], [allocator]) -> tdst -Write into `tdst` the result of transposing the indices of `tsrc`. +Compute `tdst = β * tdst + α * transpose(tsrc, (p₁, p₂))`, writing the result into `tdst`. The codomain and domain of `tdst` correspond to the indices in `p₁` and `p₂` of `tsrc` respectively. The new index positions should be attainable without any indices crossing each other, i.e., -the permutation `(p₁..., reverse(p₂)...)` should constitute a cyclic permutation of `(codomainind(tsrc)..., reverse(domainind(tsrc))...)`. +the permutation `(p₁..., reverse(p₂)...)` should constitute a cyclic permutation of +`(codomainind(tsrc)..., reverse(domainind(tsrc))...)`. +Optionally specify a `backend` and `allocator` for the underlying array operation. -See [`transpose`](@ref) for creating a new tensor and [`add_transpose!`](@ref) for a more general version. +See also [`transpose`](@ref) for creating a new tensor. """ +function LinearAlgebra.transpose!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) + return transpose!(tdst, tsrc, _transpose_indices(tsrc)) +end @propagate_inbounds function LinearAlgebra.transpose!( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(tsrc) + tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, + α::Number = One(), β::Number = Zero(), + backend::AbstractBackend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() ) - return add_transpose!(tdst, tsrc, (p₁, p₂), One(), Zero()) + @boundscheck spacecheck_transform(transpose, tdst, tsrc, p) + transformer = treetransposer(tdst, tsrc, p) + return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend, allocator) end """ - transpose(tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple; - copy::Bool=false) - -> tdst::TensorMap + transpose( + tsrc, (p₁, p₂)::Index2Tuple; copy = false, + backend = DefaultBackend(), allocator = DefaultAllocator() + ) -> tdst::TensorMap Return tensor `tdst` obtained by transposing the indices of `tsrc`. The codomain and domain of `tdst` correspond to the indices in `p₁` and `p₂` of `tsrc` respectively. The new index positions should be attainable without any indices crossing each other, i.e., -the permutation `(p₁..., reverse(p₂)...)` should constitute a cyclic permutation of `(codomainind(tsrc)..., reverse(domainind(tsrc))...)`. +the permutation `(p₁..., reverse(p₂)...)` should constitute a cyclic permutation of +`(codomainind(tsrc)..., reverse(domainind(tsrc))...)`. If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. +Optionally specify a `backend` and `allocator` for the underlying array operation. -To permute into an existing destination, see [permute!](@ref) and [`add_permute!`](@ref) +See also [`transpose!`](@ref) for writing into an existing destination. """ function LinearAlgebra.transpose( t::AbstractTensorMap, p::Index2Tuple = _transpose_indices(t); - copy::Bool = false + copy::Bool = false, backend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() ) - sectortype(t) === Trivial && return permute(t, p; copy) + sectortype(t) === Trivial && return permute(t, p; copy, backend, allocator) (!copy && p == (codomainind(t), domainind(t))) && return t # general case tdst = similar(t, promote_transpose(t), permute(space(t), p)) - return @inbounds transpose!(tdst, t, p) + return @inbounds transpose!(tdst, t, p, One(), Zero(), backend, allocator) end function LinearAlgebra.transpose( t::AdjointTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(t); - copy::Bool = false + copy::Bool = false, backend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() ) p₁′ = map(n -> adjointtensorindex(t, n), p₂) p₂′ = map(n -> adjointtensorindex(t, n), p₁) - return adjoint(transpose(adjoint(t), (p₁′, p₂′); copy = copy)) + return adjoint(transpose(adjoint(t), (p₁′, p₂′); copy, backend, allocator)) end +# ------------------- +# repartition(!) +# ------------------- """ - repartition!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) -> tdst + repartition!(tdst, tsrc, α = 1, β = 0, [backend], [allocator]) -> tdst -Write into `tdst` the result of repartitioning the indices of `tsrc`. This is just a special -case of a transposition that only changes the number of in- and outgoing indices. +Compute `tdst = β * tdst + α * repartition(tsrc)`, writing the result into `tdst`. +This is a special case of `transpose!` that only changes the partition of indices between +codomain and domain, without changing their cyclic order. +Optionally specify a `backend` and `allocator` for the underlying array operation. -See [`repartition`](@ref) for creating a new tensor. +See also [`repartition`](@ref) for creating a new tensor. """ -@propagate_inbounds function repartition!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) +@propagate_inbounds function repartition!( + tdst::AbstractTensorMap, tsrc::AbstractTensorMap, + α::Number = One(), β::Number = Zero(), + backend::AbstractBackend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() + ) check_spacetype(tdst, tsrc) numind(tsrc) == numind(tdst) || throw(ArgumentError("tsrc and tdst should have an equal amount of indices")) - all_inds = (codomainind(tsrc)..., reverse(domainind(tsrc))...) - p₁ = ntuple(i -> all_inds[i], numout(tdst)) - p₂ = reverse(ntuple(i -> all_inds[i + numout(tdst)], numin(tdst))) - return transpose!(tdst, tsrc, (p₁, p₂)) + p₁, p₂ = let all_inds = (codomainind(tsrc)..., reverse(domainind(tsrc))...) + ntuple(i -> all_inds[i], numout(tdst)), reverse(ntuple(i -> all_inds[i + numout(tdst)], numin(tdst))) + end + return transpose!(tdst, tsrc, (p₁, p₂), α, β, backend, allocator) end """ repartition( - tsrc::AbstractTensorMap{T, S}, N₁::Int, N₂::Int; copy::Bool=false - ) where {T, S} -> tdst::AbstractTensorMap{T, S, N₁, N₂} + tsrc, N₁::Int, N₂::Int = numind(tsrc) - N₁; copy = false, + backend = DefaultBackend(), allocator = DefaultAllocator() + ) -> tdst -Return tensor `tdst` obtained by repartitioning the indices of `t`. -The codomain and domain of `tdst` correspond to the first `N₁` and last `N₂` spaces of `t`, respectively. +Return tensor `tdst` obtained by repartitioning the indices of `tsrc`. +The codomain and domain of `tdst` correspond to the first `N₁` and last `N₂` spaces of `tsrc`, +respectively. If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. +Optionally specify a `backend` and `allocator` for the underlying array operation. -To repartition into an existing destination, see [repartition!](@ref). +See also [`repartition!`](@ref) for writing into an existing destination. """ @constprop :aggressive function repartition( - t::AbstractTensorMap, N₁::Int, N₂::Int = numind(t) - N₁; copy::Bool = false + t::AbstractTensorMap, N₁::Int, N₂::Int = numind(t) - N₁; + copy::Bool = false, backend = TO.DefaultBackend(), allocator = TO.DefaultAllocator() ) N₁ + N₂ == numind(t) || throw(ArgumentError("Invalid repartition: $(numind(t)) to ($N₁, $N₂)")) - all_inds = (codomainind(t)..., reverse(domainind(t))...) - p₁ = ntuple(i -> all_inds[i], N₁) - p₂ = reverse(ntuple(i -> all_inds[i + N₁], N₂)) - return transpose(t, (p₁, p₂); copy) -end - -# Twist -function has_shared_twist(t, inds) - I = sectortype(t) - if BraidingStyle(I) == NoBraiding() - for i in inds - cs = sectors(space(t, i)) - all(isunit, cs) || throw(SectorMismatch(lazy"Cannot twist sectors $cs")) - end - return true - elseif BraidingStyle(I) == Bosonic() - return true - else - for i in inds - cs = sectors(space(t, i)) - all(isone ∘ twist, cs) || return false - end - return true - end -end - -""" - twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) -> t - twist!(t::AbstractTensorMap, inds; inv::Bool=false) -> t - -Apply a twist to the `i`th index of `t`, or all indices in `inds`, storing the result in `t`. -If `inv=true`, use the inverse twist. - -See [`twist`](@ref) for creating a new tensor. -""" -function twist!(t::AbstractTensorMap, inds; inv::Bool = false) - if !all(in(allind(t)), inds) - msg = "Can't twist indices $inds of a tensor with only $(numind(t)) indices." - throw(ArgumentError(msg)) - end - (scalartype(t) <: Real && !(sectorscalartype(sectortype(t)) <: Real)) && - throw(ArgumentError("Can't in-place twist a real tensor with complex sector type")) - has_shared_twist(t, inds) && return t - - (scalartype(t) <: Real && !(sectorscalartype(sectortype(t)) <: Real)) && - throw(ArgumentError("No in-place `twist!` for a real tensor with complex sector type")) - - N₁ = numout(t) - for (f₁, f₂) in fusiontrees(t) - θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), inds) - inv && (θ = θ') - scale!(t[f₁, f₂], θ) - end - return t -end - -""" - twist(tsrc::AbstractTensorMap, i::Int; inv::Bool = false, copy::Bool = false) -> tdst - twist(tsrc::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) -> tdst - -Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor. -If `inv = true`, use the inverse twist. -If `copy = false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. - -See [`twist!`](@ref) for storing the result in place. -""" -function twist(t::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) - !copy && has_shared_twist(t, inds) && return t - tdst = similar(t, promote_twist(t)) - copy!(tdst, t) - return twist!(tdst, inds; inv) -end - -# Methods which change the number of indices, implement using `Val(i)` for type inference -""" - insertleftunit(tsrc::AbstractTensorMap, i=numind(t) + 1; - conj=false, dual=false, copy=false) -> tdst - -Insert a trivial vector space, isomorphic to the underlying field, at position `i`, -which can be specified as an `Int` or as `Val(i)` for improved type stability. -More specifically, adds a left monoidal unit or its dual. - -If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. - -See also [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) where {i}), -[`removeunit`](@ref removeunit(::AbstractTensorMap, ::Val{i}) where {i}). -""" -function insertleftunit( - t::AbstractTensorMap, ::Val{i} = Val(numind(t) + 1); - copy::Bool = false, conj::Bool = false, dual::Bool = false - ) where {i} - W = insertleftunit(space(t), Val(i); conj, dual) - if t isa TensorMap - return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) - else - tdst = similar(t, W) - for (c, b) in blocks(t) - copy!(block(tdst, c), b) - end - return tdst + p₁, p₂ = let all_inds = (codomainind(t)..., reverse(domainind(t))...) + ntuple(i -> all_inds[i], N₁), reverse(ntuple(i -> all_inds[i + N₁], N₂)) end + return transpose(t, (p₁, p₂); copy, backend, allocator) end -""" - insertrightunit(tsrc::AbstractTensorMap, i=numind(t); - conj=false, dual=false, copy=false) -> tdst - -Insert a trivial vector space, isomorphic to the underlying field, after position `i`, -which can be specified as an `Int` or as `Val(i)` for improved type stability. -More specifically, adds a right monoidal unit or its dual. - -If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. - -See also [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) where {i}), -[`removeunit`](@ref removeunit(::AbstractTensorMap, ::Val{i}) where {i}). -""" -function insertrightunit( - t::AbstractTensorMap, ::Val{i} = Val(numind(t)); - copy::Bool = false, conj::Bool = false, dual::Bool = false - ) where {i} - W = insertrightunit(space(t), Val(i); conj, dual) - if t isa TensorMap - return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) - else - tdst = similar(t, W) - for (c, b) in blocks(t) - copy!(block(tdst, c), b) - end - return tdst - end -end - -""" - removeunit(tsrc::AbstractTensorMap, i; copy=false) -> tdst - -This removes a trivial tensor product factor at position `1 ≤ i ≤ N`, where `i` -can be specified as an `Int` or as `Val(i)` for improved type stability. -For this to work, that factor has to be isomorphic to the field of scalars. +#------------------------------------- +# Internal implementations +#------------------------------------- -If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. +# find the scalartype after applying operations: take into account fusion and/or braiding +# might need to become Float or Complex to capture complex recoupling coefficients but don't alter precision +for (operation, manipulation) in ( + :flip => :sector, :twist => :braiding, + :transpose => :fusion, :permute => :sector, :braid => :sector, + ) + promote_op = Symbol(:promote_, operation) + manipulation_scalartype = Symbol(manipulation, :scalartype) -This operation undoes the work of [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) where {i}) -and [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) where {i}). -""" -function removeunit(t::AbstractTensorMap, ::Val{i}; copy::Bool = false) where {i} - W = removeunit(space(t), Val(i)) - if t isa TensorMap - return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) - else - tdst = similar(t, W) - for (c, b) in blocks(t) - copy!(block(tdst, c), b) - end - return tdst + @eval begin + $promote_op(t::AbstractTensorMap) = $promote_op(typeof(t)) + $promote_op(::Type{T}) where {T <: AbstractTensorMap} = + $promote_op(scalartype(T), sectortype(T)) + $promote_op(::Type{T}, ::Type{I}) where {T <: Number, I <: Sector} = + $manipulation_scalartype(I) <: Integer ? T : + $manipulation_scalartype(I) <: Real ? float(T) : complex(T) end end -# Fusing and splitting -# TODO: add functionality for easy fusing and splitting of tensor indices - -#------------------------------------- -# Full implementations based on `add` -#------------------------------------- spacecheck_transform(f, tdst::AbstractTensorMap, tsrc::AbstractTensorMap, args...) = spacecheck_transform(f, space(tdst), space(tsrc), args...) @noinline function spacecheck_transform(f, Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple) @@ -447,67 +527,26 @@ end return nothing end - -""" - add_permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, - α::Number, β::Number, backend::AbstractBackend...) - -Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after permuting -the indices of `tsrc` according to `(p₁, p₂)`. - -See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_transpose!`](@ref). -""" -@propagate_inbounds function add_permute!( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, - α::Number, β::Number, backend::AbstractBackend... - ) - @boundscheck spacecheck_transform(permute, tdst, tsrc, p) - transformer = treepermuter(tdst, tsrc, p) - return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...) -end - -""" - add_braid!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, - levels::IndexTuple, α::Number, β::Number, backend::AbstractBackend...) - -Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after braiding -the indices of `tsrc` according to `(p₁, p₂)` and `levels`. - -See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transpose!`](@ref). -""" -@propagate_inbounds function add_braid!( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple, - α::Number, β::Number, backend::AbstractBackend... - ) - @boundscheck spacecheck_transform(braid, tdst, tsrc, p, levels) - levels1 = TupleTools.getindices(levels, codomainind(tsrc)) - levels2 = TupleTools.getindices(levels, domainind(tsrc)) - # TODO: arg order for tensormaps is different than for fusiontrees - transformer = treebraider(tdst, tsrc, p, (levels1, levels2)) - return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...) -end - -""" - add_transpose!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, - α::Number, β::Number, backend::AbstractBackend...) - -Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after transposing -the indices of `tsrc` according to `(p₁, p₂)`. - -See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`add_braid!`](@ref). -""" -@propagate_inbounds function add_transpose!( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, - α::Number, β::Number, backend::AbstractBackend... - ) - @boundscheck spacecheck_transform(transpose, tdst, tsrc, p) - transformer = treetransposer(tdst, tsrc, p) - return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...) -end - +# Deprecated add_*! wrappers +# -------------------------- +Base.@deprecate( + add_permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, α::Number, β::Number, backend::AbstractBackend...), + permute!(tdst, tsrc, p, α, β, backend...) +) +Base.@deprecate( + add_braid!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple, α::Number, β::Number, backend::AbstractBackend...), + braid!(tdst, tsrc, p, levels, α, β, backend...) +) +Base.@deprecate( + add_transpose!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, α::Number, β::Number, backend::AbstractBackend...), + transpose!(tdst, tsrc, p, α, β, backend...) +) + +# Kernel implementation +# --------------------- @propagate_inbounds function add_transform!( tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, transformer, - α::Number, β::Number, backend::AbstractBackend... + α::Number, β::Number, backend, allocator ) @boundscheck spacecheck_transform(permute, tdst, tsrc, p) @@ -516,13 +555,14 @@ end else I = sectortype(tdst) if I === Trivial - add_trivial_kernel!(tdst, tsrc, p, transformer, α, β, backend...) + TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend, allocator) else - style = FusionStyle(I) - if use_threaded_transform(tdst, transformer) - add_kernel_threaded!(style, tdst, tsrc, p, transformer, α, β, backend...) + ntasks = use_threaded_transform(tdst, transformer) ? get_num_transformer_threads() : 1 + scheduler = ntasks == 1 ? SerialScheduler() : DynamicScheduler(; ntasks, split = :roundrobin) + if tdst isa TensorMap && tsrc isa TensorMap # unpack data fields to avoid specializing + add_transform_kernel!(tdst.data, tsrc.data, p, transformer, α, β, backend, allocator, scheduler) else - add_kernel_nonthreaded!(style, tdst, tsrc, p, transformer, α, β, backend...) + add_transform_kernel!(tdst, tsrc, p, transformer, α, β, backend, allocator, scheduler) end end end @@ -537,267 +577,141 @@ function use_threaded_transform(t::AbstractTensorMap, transformer) return get_num_transformer_threads() > 1 && dim(space(t)) > Strided.MINTHREADLENGTH end -# Trivial implementations -# ----------------------- -function add_trivial_kernel!(tdst, tsrc, p, transformer, α, β, backend...) - TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend...) - return nothing -end - -# Non-threaded implementations -# ---------------------------- -function add_kernel_nonthreaded!( - ::UniqueFusion, tdst, tsrc, p, transformer, α, β, backend... +function add_transform_kernel!( + tdst, tsrc, p, transformer, α, β, backend, allocator, scheduler ) - for (f₁, f₂) in fusiontrees(tsrc) - _add_transform_single!(tdst, tsrc, p, (f₁, f₂), transformer, α, β, backend...) - end - return nothing -end -function add_kernel_nonthreaded!( - ::UniqueFusion, tdst, tsrc, p, transformer::AbelianTreeTransformer, α, β, backend... - ) - for subtransformer in transformer.data - _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...) - end - return nothing -end -function add_kernel_nonthreaded!(::FusionStyle, tdst, tsrc, p, transformer, α, β, backend...) - # preallocate buffers - buffers = allocate_buffers(tdst, tsrc, transformer) - - for src in fusionblocks(tsrc) - if length(src) == 1 - _add_transform_single!(tdst, tsrc, p, src, transformer, α, β, backend...) - else - _add_transform_multi!(tdst, tsrc, p, src, transformer, buffers, α, β, backend...) + I = sectortype(tdst) + if FusionStyle(I) === UniqueFusion() + tforeach(fusiontrees(tsrc); scheduler) do (f₁, f₂) + (f₁′, f₂′), coeff = transformer((f₁, f₂)) + @inbounds TO.tensoradd!( + tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β, backend, allocator + ) end + return nothing end - return nothing -end -# specialization in the case of TensorMap -function add_kernel_nonthreaded!( - ::FusionStyle, tdst, tsrc, p, transformer::GenericTreeTransformer, α, β, backend... - ) - # preallocate buffers - buffers = allocate_buffers(tdst, tsrc, transformer) + cp = TO.allocator_checkpoint!(allocator) + + # buffers have to be created without race condition: err on the side of caution with a lock + buffer_lock = Threads.ReentrantLock() + + OhMyThreads.@tasks for src in fusionblocks(tsrc) + # setup + OhMyThreads.@set scheduler = scheduler + dst, U = transformer(src) + + if length(src) == 1 # Degenerate block with a single tree: no matmul needed. + (f₁, f₂) = only(fusiontrees(src)) + (f₁′, f₂′) = only(fusiontrees(dst)) + @inbounds TO.tensoradd!( + tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * only(U), β, backend, allocator + ) + else # Multi-tree block: pack → recoupling matmul → unpack. + rows, cols = size(U) + sz_src = size(tsrc[first(fusiontrees(src))...]) + blocksize = prod(sz_src) + buffer = @lock buffer_lock TO.tensoralloc(storagetype(tdst), blocksize * (rows + cols), Val(true), allocator) + ptriv = (ntuple(identity, length(sz_src)), ()) + buffer_dst = StridedView(buffer, (blocksize, rows), (1, blocksize), 0) + buffer_src = StridedView(buffer, (blocksize, cols), (1, blocksize), blocksize * rows) + + # 1. Extract: copy each source block into column i of buffer_src as a flat vector, + # using a trivial permutation so the layout is canonical before the matmul. + @inbounds for (i, (f₁, f₂)) in enumerate(fusiontrees(src)) + TO.tensoradd!( + sreshape(buffer_src[:, i], sz_src), tsrc[f₁, f₂], + ptriv, false, One(), Zero(), backend, allocator + ) + end - for subtransformer in transformer.data - # Special case without intermediate buffers whenever there is only a single block - if length(subtransformer[1]) == 1 - _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...) - else - _add_transform_multi!(tdst, tsrc, p, subtransformer, buffers, α, β, backend...) - end - end - return nothing -end -# ambiguity resolution -function add_kernel_nonthreaded!( - ::UniqueFusion, tdst, tsrc, p, transformer::GenericTreeTransformer, α, β, backend... - ) - throw(ArgumentError("Cannot combine `GenericTreeTransformer` with `UniqueFusion`")) -end -# Threaded implementations -# ------------------------ -function add_kernel_threaded!( - ::UniqueFusion, tdst, tsrc, p, transformer, α, β, backend...; - ntasks::Int = get_num_transformer_threads() - ) - trees = fusiontrees(tsrc) - nblocks = length(trees) - counter = Threads.Atomic{Int}(1) - Threads.@sync for _ in 1:min(ntasks, nblocks) - Threads.@spawn begin - while true - local_counter = Threads.atomic_add!(counter, 1) - local_counter > nblocks && break - @inbounds (f₁, f₂) = trees[local_counter] - _add_transform_single!(tdst, tsrc, p, (f₁, f₂), transformer, α, β, backend...) + # 2. Recoupling: buffer_dst = buffer_src * U^T (each output tree is a linear + # combination of input trees weighted by the recoupling coefficients). + U′ = Adapt.adapt(storagetype(tdst), StridedView(U)) + mul!(buffer_dst, buffer_src, transpose(U′)) + + # 3. Insert: scatter column i of buffer_dst into the destination, applying the + # actual index permutation p in the same tensoradd! call. + @inbounds for (i, (f₃, f₄)) in enumerate(fusiontrees(dst)) + TO.tensoradd!( + tdst[f₃, f₄], sreshape(buffer_dst[:, i], sz_src), + p, false, α, β, backend, allocator + ) end + @lock buffer_lock TO.tensorfree!(buffer, allocator) end end + TO.allocator_reset!(allocator, cp) return nothing end -function add_kernel_threaded!( - ::UniqueFusion, tdst, tsrc, p, transformer::AbelianTreeTransformer, α, β, backend...; - ntasks::Int = get_num_transformer_threads() + +# TensorMap specializations: operate directly on the flat data vector to avoid +# repeated specialization -- this only depends on `numind` and `eltype`. +function add_transform_kernel!( + data_dst::DenseVector, data_src::DenseVector, p, transformer::AbelianTreeTransformer, + α, β, backend, allocator, scheduler ) - nblocks = length(transformer.data) - counter = Threads.Atomic{Int}(1) - Threads.@sync for _ in 1:min(ntasks, nblocks) - Threads.@spawn begin - while true - local_counter = Threads.atomic_add!(counter, 1) - local_counter > nblocks && break - @inbounds subtransformer = transformer.data[local_counter] - _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...) - end - end + tforeach(transformer.data; scheduler) do (coeff, struct_dst, struct_src) + TO.tensoradd!( + StridedView(data_dst, struct_dst...), StridedView(data_src, struct_src...), + p, false, α * coeff, β, backend, allocator + ) end return nothing end - -function add_kernel_threaded!( - ::FusionStyle, tdst, tsrc, p, transformer, α, β, backend...; - ntasks::Int = get_num_transformer_threads() +function add_transform_kernel!( + data_dst::DenseVector, data_src::DenseVector, p, transformer::GenericTreeTransformer, + α, β, backend, allocator, scheduler ) - allblocks = fusionblocks(tsrc) - nblocks = length(allblocks) - - counter = Threads.Atomic{Int}(1) - Threads.@sync for _ in 1:min(ntasks, nblocks) - Threads.@spawn begin - # preallocate buffers for each task - buffers = allocate_buffers(tdst, tsrc, transformer) - - while true - local_counter = Threads.atomic_add!(counter, 1) - local_counter > nblocks && break - @inbounds src = allblocks[local_counter] - if length(src) == 1 - _add_transform_single!(tdst, tsrc, p, src, transformer, α, β, backend...) - else - _add_transform_multi!(tdst, tsrc, p, src, transformer, buffers, α, β, backend...) - end + cp = TO.allocator_checkpoint!(allocator) + + # buffers have to be created without race condition: err on the side of caution with a lock + buffer_lock = Threads.ReentrantLock() + + OhMyThreads.@tasks for subtransformer in transformer.data + # setup + OhMyThreads.@set scheduler = scheduler + U, (sz_dst, structs_dst), (sz_src, structs_src) = subtransformer + + if length(U) == 1 # Degenerate block with a single tree: no matmul needed. + coeff = only(U) + TO.tensoradd!( + StridedView(data_dst, sz_dst, only(structs_dst)...), + StridedView(data_src, sz_src, only(structs_src)...), + p, false, α * coeff, β, backend, allocator + ) + else # Multi-tree block: pack → recoupling matmul → unpack. + rows, cols = size(U) + blocksize = prod(sz_src) + buffer = @lock buffer_lock TO.tensoralloc(typeof(data_dst), blocksize * (rows + cols), Val(true), allocator) + ptriv = (ntuple(identity, length(sz_src)), ()) + buffer_dst = StridedView(buffer, (blocksize, rows), (1, blocksize), 0) + buffer_src = StridedView(buffer, (blocksize, cols), (1, blocksize), blocksize * rows) + + # 1. Extract: copy each source block into column i of buffer_src as a flat vector, + # using a trivial permutation so the layout is canonical before the matmul. + @inbounds for (i, struct_src_i) in enumerate(structs_src) + TO.tensoradd!( + sreshape(buffer_src[:, i], sz_src), StridedView(data_src, sz_src, struct_src_i...), + ptriv, false, One(), Zero(), backend, allocator + ) end - end - end - return nothing -end -# specialization in the case of TensorMap -function add_kernel_threaded!( - ::FusionStyle, tdst, tsrc, p, transformer::GenericTreeTransformer, α, β, backend...; - ntasks::Int = get_num_transformer_threads() - ) - nblocks = length(transformer.data) - - counter = Threads.Atomic{Int}(1) - Threads.@sync for _ in 1:min(ntasks, nblocks) - Threads.@spawn begin - # preallocate buffers for each task - buffers = allocate_buffers(tdst, tsrc, transformer) - - while true - local_counter = Threads.atomic_add!(counter, 1) - local_counter > nblocks && break - @inbounds subtransformer = transformer.data[local_counter] - if length(subtransformer[1]) == 1 - _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...) - else - _add_transform_multi!(tdst, tsrc, p, subtransformer, buffers, α, β, backend...) - end + # 2. Recoupling: buffer_dst = buffer_src * U^T (each output tree is a linear + # combination of input trees weighted by the recoupling coefficients). + U′ = Adapt.adapt(typeof(data_dst), StridedView(U)) + mul!(buffer_dst, buffer_src, transpose(U′)) + + # 3. Insert: scatter column i of buffer_dst into the destination, applying the + # actual index permutation p in the same tensoradd! call. + @inbounds for (i, struct_dst_i) in enumerate(structs_dst) + TO.tensoradd!( + StridedView(data_dst, sz_dst, struct_dst_i...), sreshape(buffer_dst[:, i], sz_src), + p, false, α, β, backend, allocator + ) end + @lock buffer_lock TO.tensorfree!(buffer, allocator) end end - - return nothing -end -# ambiguity resolution -function add_kernel_threaded!( - ::UniqueFusion, tdst, tsrc, p, transformer::GenericTreeTransformer, α, β, backend...; - ntasks::Int = get_num_transformer_threads() - ) - throw(ArgumentError("Cannot combine `GenericTreeTransformer` with `UniqueFusion`")) -end - - -# Auxiliary methods -# ----------------- -function _add_transform_single!(tdst, tsrc, p, (f₁, f₂)::FusionTreePair, transformer, α, β, backend...) - (f₁′, f₂′), coeff = transformer((f₁, f₂)) - @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β, backend...) - return nothing -end -function _add_transform_single!(tdst, tsrc, p, src::FusionTreeBlock, transformer, α, β, backend...) - dst, U = transformer(src) - f₁, f₂ = only(fusiontrees(src)) - f₁′, f₂′ = only(fusiontrees(dst)) - coeff = only(U) - @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β, backend...) - return nothing -end -function _add_transform_single!( - tdst, tsrc, p, (coeff, struct_dst, struct_src)::AbelianTransformerData, - α, β, backend... - ) - subblock_dst = StridedView(tdst.data, struct_dst...) - subblock_src = StridedView(tsrc.data, struct_src...) - TO.tensoradd!(subblock_dst, subblock_src, p, false, α * coeff, β, backend...) - return nothing -end -function _add_transform_single!( - tdst, tsrc, p, (basistransform, structs_dst, structs_src)::GenericTransformerData, - α, β, backend... - ) - struct_dst = (structs_dst[1], only(structs_dst[2])...) - struct_src = (structs_src[1], only(structs_src[2])...) - coeff = only(basistransform) - _add_transform_single!(tdst, tsrc, p, (coeff, struct_dst, struct_src), α, β, backend...) - return nothing -end - -function _add_transform_multi!(tdst, tsrc, p, src::FusionTreeBlock, transformer, (buffer1, buffer2), α, β, backend...) - dst, U = transformer(src) - rows, cols = size(U) - sz_src = size(tsrc[first(fusiontrees(src))...]) - blocksize = prod(sz_src) - matsize = ( - prod(TupleTools.getindices(sz_src, codomainind(tsrc))), - prod(TupleTools.getindices(sz_src, domainind(tsrc))), - ) - - # Filling up a buffer with contiguous data - buffer_src = StridedView(buffer2, (blocksize, cols), (1, blocksize), 0) - for (i, (f₁, f₂)) in enumerate(fusiontrees(src)) - subblock_src = sreshape(tsrc[f₁, f₂], matsize) - bufblock_src = sreshape(buffer_src[:, i], matsize) - copy!(bufblock_src, subblock_src) - end - - # Resummation into a second buffer using BLAS - buffer_dst = StridedView(buffer1, (blocksize, rows), (1, blocksize), 0) - mul!(buffer_dst, buffer_src, transpose(StridedView(U)), α, Zero()) - - # Filling up the output - for (i, (f₃, f₄)) in enumerate(fusiontrees(dst)) - subblock_dst = tdst[f₃, f₄] - bufblock_dst = sreshape(buffer_dst[:, i], sz_src) - TO.tensoradd!(subblock_dst, bufblock_dst, p, false, One(), β, backend...) - end - - return nothing -end -function _add_transform_multi!( - tdst, tsrc, p, (U, (sz_dst, structs_dst), (sz_src, structs_src)), - (buffer1, buffer2), α, β, backend... - ) - rows, cols = size(U) - blocksize = prod(sz_src) - matsize = ( - prod(TupleTools.getindices(sz_src, codomainind(tsrc))), - prod(TupleTools.getindices(sz_src, domainind(tsrc))), - ) - - # Filling up a buffer with contiguous data - buffer_src = StridedView(buffer2, (blocksize, cols), (1, blocksize), 0) - for (i, struct_src) in enumerate(structs_src) - subblock_src = sreshape(StridedView(tsrc.data, sz_src, struct_src...), matsize) - bufblock_src = sreshape(buffer_src[:, i], matsize) - copy!(bufblock_src, subblock_src) - end - - # Resummation into a second buffer using BLAS - buffer_dst = StridedView(buffer1, (blocksize, rows), (1, blocksize), 0) - mul!(buffer_dst, buffer_src, transpose(StridedView(U)), α, Zero()) - - # Filling up the output - for (i, struct_dst) in enumerate(structs_dst) - subblock_dst = StridedView(tdst.data, sz_dst, struct_dst...) - bufblock_dst = sreshape(buffer_dst[:, i], sz_src) - TO.tensoradd!(subblock_dst, bufblock_dst, p, false, One(), β, backend...) - end - + TO.allocator_reset!(allocator, cp) return nothing end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index bd6609163..64ae90a69 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -554,3 +554,8 @@ function Base.promote_rule( A = promote_storagetype(VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂)), TT₁, TT₂) return tensormaptype(S, N₁, N₂, A) end + +function Adapt.adapt_structure(to, x::TensorMap) + data = Adapt.adapt(to, x.data) + return TensorMap{eltype(data)}(data, space(x)) +end diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 3fc79cf0c..375b63768 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -43,9 +43,9 @@ function TO.tensoradd!( if conjA A′ = adjoint(A) pA′ = adjointtensorindices(A, _canonicalize(pA, C)) - add_permute!(C, A′, pA′, α, β, backend) + permute!(C, A′, pA′, α, β, backend) else - add_permute!(C, A, _canonicalize(pA, C), α, β, backend) + permute!(C, A, _canonicalize(pA, C), α, β, backend) end return C end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 30ec1de0f..f68378cc2 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -128,25 +128,6 @@ function repack_transformer_structure(structures::Dictionary, trees) return sz, strides_offsets end -function buffersize(transformer::GenericTreeTransformer) - return maximum(transformer.data; init = 0) do (basistransform, structures_dst, _) - return prod(structures_dst[1]) * size(basistransform, 1) - end -end - -function allocate_buffers( - tdst::TensorMap, tsrc::TensorMap, transformer::GenericTreeTransformer - ) - sz = buffersize(transformer) - return similar(tdst.data, sz), similar(tsrc.data, sz) -end -function allocate_buffers( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, transformer - ) - # be pessimistic and assume the worst for now - sz = dim(space(tsrc)) - return similar(storagetype(tdst), sz), similar(storagetype(tsrc), sz) -end function treetransformertype(Vdst, Vsrc) I = sectortype(Vdst) @@ -185,22 +166,17 @@ end return TreeTransformer(fusiontreebraider, p, Vdst, Vsrc) end -for (transform, treetransformer) in - ((:permute, :treepermuter), (:transpose, :treetransposer)) - @eval begin - function $treetransformer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple) - return fusiontreetransform(f) = $transform(f, p) - end - function $treetransformer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple) - return $treetransformer(space(tdst), space(tsrc), p) - end - @cached function $treetransformer( - Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple - )::treetransformertype(Vdst, Vsrc) - fusiontreetransform(f) = $transform(f, p) - return TreeTransformer(fusiontreetransform, p, Vdst, Vsrc) - end - end +function treetransposer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple) + return fusiontreetransform(f) = transpose(f, p) +end +function treetransposer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple) + return treetransposer(space(tdst), space(tsrc), p) +end +@cached function treetransposer( + Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple + )::treetransformertype(Vdst, Vsrc) + fusiontreetransform(f) = transpose(f, p) + return TreeTransformer(fusiontreetransform, p, Vdst, Vsrc) end # default cachestyle is GlobalLRUCache @@ -227,3 +203,9 @@ end function _transformer_weight((mat, structs_dst, structs_src)::GenericTransformerData) return length(mat) * prod(structs_dst[1]) end + +function buffersize(transformer::GenericTreeTransformer) + return maximum(transformer.data; init = 0) do (basistransform, structures_dst, _) + return prod(structures_dst[1]) * size(basistransform, 1) + end +end diff --git a/test/factorizations/svd.jl b/test/factorizations/svd.jl index 0678db827..e2eb19076 100644 --- a/test/factorizations/svd.jl +++ b/test/factorizations/svd.jl @@ -49,9 +49,9 @@ for V in spacelist end for T in eltypes, t in (randn(T, W, W), randn(T, W, W)') project_hermitian!(t) - vals = @constinferred LinearAlgebra.eigvals(t) - λmax = maximum(s -> maximum(abs, s), values(vals)) - λmin = minimum(s -> minimum(abs, s), values(vals)) + vals = @constinferred eigh_vals(t) + λmax = maximum(abs, vals) + λmin = minimum(abs, vals) @test cond(t) ≈ λmax / λmin end end diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl index 4dd6413cf..390721d71 100644 --- a/test/mooncake/indexmanipulations.jl +++ b/test/mooncake/indexmanipulations.jl @@ -18,7 +18,7 @@ eltypes = (Float64, ComplexF64) hasbraiding = BraidingStyle(sectortype(eltype(V))) isa HasBraiding symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding - symmetricbraiding && @timedtestset "add_permute!" begin + symmetricbraiding && @timedtestset "permute!" begin A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') α = randn(T) β = randn(T) @@ -27,12 +27,12 @@ eltypes = (Float64, ComplexF64) for _ in 1:5 p = randindextuple(numind(A)) C = randn!(permute(A, p)) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_permute!, C, A, p, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.permute!, C, A, p, α, β; atol, rtol, mode) A = C end end - @timedtestset "add_transpose!" begin + @timedtestset "transpose!" begin A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') α = randn(T) β = randn(T) @@ -41,18 +41,18 @@ eltypes = (Float64, ComplexF64) for _ in 1:2 p = randcircshift(numout(A), numin(A)) C = randn!(transpose(A, p)) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, One(), Zero(); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.transpose!, C, A, p, One(), Zero(); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.transpose!, C, A, p, α, β; atol, rtol, mode) if !(T <: Real) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, real(A), p, α, β; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, real(α), β; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, real(A), p, real(α), β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.transpose!, C, real(A), p, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.transpose!, C, A, p, real(α), β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.transpose!, C, real(A), p, real(α), β; atol, rtol, mode) end A = C end end - hasbraiding && @timedtestset "add_braid!" begin + hasbraiding && @timedtestset "braid!" begin A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') α = randn(T) β = randn(T) @@ -62,11 +62,11 @@ eltypes = (Float64, ComplexF64) p = randcircshift(numout(A), numin(A)) levels = Tuple(randperm(numind(A))) C = randn!(transpose(A, p)) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_braid!, C, A, p, levels, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.braid!, C, A, p, levels, α, β; atol, rtol, mode) if !(T <: Real) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_braid!, C, real(A), p, levels, α, β; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_braid!, C, A, p, levels, real(α), β; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_braid!, C, A, p, levels, real(α), real(β); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.braid!, C, real(A), p, levels, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.braid!, C, A, p, levels, real(α), β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.braid!, C, A, p, levels, real(α), real(β); atol, rtol, mode) end A = C end diff --git a/test/tensors/indexmanipulations.jl b/test/tensors/indexmanipulations.jl index c38b182a2..836418b3f 100644 --- a/test/tensors/indexmanipulations.jl +++ b/test/tensors/indexmanipulations.jl @@ -129,6 +129,14 @@ for V in spacelist @tensor tb[a, b] := flip(t1, (1, 3))[x, y, a, z] * flip(t2, (2, 4))[y, b, z, x] @test flip(ta, (1, 2)) ≈ tb end + hasbraiding && !symmetricbraiding && @timedtestset "Braid AdjointTensorMap: adjoint identity" begin + t = rand(ComplexF64, V1 ⊗ V2 ← V3) + p = ((2,), (1, 3)) + levels = (1, 3, 2) + t1 = copy(braid(t', p, levels)) + t2 = braid(copy(t'), p, levels) + @test t1 ≈ t2 + end end TensorKit.empty_globalcaches!() end