Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/grad_vector.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import QuantumControl.QuantumPropagators: _exp_prop_convert_state
import QuantumControl.QuantumPropagators.Interfaces: supports_inplace
import QuantumControl.QuantumPropagators.Interfaces:
supports_inplace, supports_vector_interface


@doc raw"""Extended state-vector for the dynamic gradient.
Expand Down Expand Up @@ -68,8 +69,8 @@ in-place operations.

Returns `Ψ̃`.
"""
function resetgradvec!(Ψ̃::GradVector)
if supports_inplace(Ψ̃)
function resetgradvec!(Ψ̃::T) where {T<:GradVector}
if supports_inplace(T)
for i in eachindex(Ψ̃.grad_states)
fill!(Ψ̃.grad_states[i], 0.0)
end
Expand All @@ -89,4 +90,7 @@ end

_exp_prop_convert_state(::GradVector) = Vector{ComplexF64}

supports_inplace(Ψ̃::GradVector) = supports_inplace(Ψ̃.state)
supports_inplace(::Type{GradVector{N,T}}) where {N,T} = supports_inplace(T)

supports_vector_interface(::Type{GradVector{N,T}}) where {N,T} =
supports_vector_interface(T)
9 changes: 7 additions & 2 deletions src/gradgen_operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ using Random: GLOBAL_RNG
import QuantumControl.QuantumPropagators: _exp_prop_convert_operator
import QuantumControl.QuantumPropagators.Controls: get_controls
import QuantumControl.QuantumPropagators.SpectralRange: random_state
import QuantumControl.QuantumPropagators.Interfaces: supports_inplace
import QuantumControl.QuantumPropagators.Interfaces:
supports_inplace, supports_matrix_interface


"""Static generator for the dynamic gradient.
Expand Down Expand Up @@ -40,4 +41,8 @@ end

_exp_prop_convert_operator(::GradgenOperator) = Matrix{ComplexF64}

supports_inplace(::GradgenOperator) = true
supports_inplace(::Type{GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} =
(supports_inplace(GT) && supports_inplace(CGT))

supports_matrix_interface(::Type{<:GradgenOperator{N,GT,CGT}}) where {N,GT,CGT} =
supports_matrix_interface(GT) && supports_matrix_interface(CGT)
262 changes: 245 additions & 17 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ function LinearAlgebra.mul!(Φ::GradVector, G::GradgenOperator, Ψ::GradVector,
end


function LinearAlgebra.mul!(Φ::GradVector, G::GradgenOperator, Ψ::GradVector)
return LinearAlgebra.mul!(Φ, G, Ψ, true, false)
end


function LinearAlgebra.lmul!(c, Ψ::GradVector)
LinearAlgebra.lmul!(c, Ψ.state)
for i ∈ eachindex(Ψ.grad_states)
Expand Down Expand Up @@ -48,6 +53,11 @@ function LinearAlgebra.dot(Ψ::GradVector, Φ::GradVector)
end


function LinearAlgebra.dot(Ψ::GradVector, G::GradgenOperator, Φ::GradVector)
return LinearAlgebra.dot(Ψ, G * Φ)
end


LinearAlgebra.ishermitian(G::GradgenOperator) = false


Expand All @@ -70,41 +80,126 @@ function Base.copy(Ψ::GradVector{num_controls,T}) where {num_controls,T}
end


function Base.length(Ψ::GradVector)
# === Vector interface for GradVector ===
#
# The following methods are part of the vector interface and are only
# meaningful when `supports_vector_interface` is true for the state type T.
# Each method delegates to a private `_name(::Val{supports}, ...)` function:
# the Val{true} method contains the implementation, and the Val{false} method
# throws an error.

function _length(::Val{true}, Ψ::GradVector)
return length(Ψ.state) * (1 + length(Ψ.grad_states))
end

function _length(::Val{false}, Ψ::GradVector)
error("$(typeof(Ψ)) does not support the vector interface")
end

function Base.size(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
return (num_controls + 1) .* size(O.G)
function Base.length(Ψ::T) where {T<:GradVector}
return _length(Val(supports_vector_interface(T)), Ψ)
end


function Base.size(
O::GradgenOperator{num_controls,GT,CGT},
dim::Integer
) where {num_controls,GT,CGT}
return (num_controls + 1) * size(O.G, dim)
function _size(::Val{true}, Ψ::GradVector{num_controls,T}) where {num_controls,T}
return ((num_controls + 1) * length(Ψ.state),)
end

function _size(::Val{false}, Ψ::GradVector)
error("$(typeof(Ψ)) does not support the vector interface")
end

function Base.similar(Ψ::GradVector{num_controls,T}) where {num_controls,T}
return GradVector{num_controls,T}(similar(Ψ.state), [similar(ϕ) for ϕ ∈ Ψ.grad_states])
function Base.size(Ψ::T) where {T<:GradVector}
return _size(Val(supports_vector_interface(T)), Ψ)
end

function Base.similar(G::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
return GradgenOperator{num_controls,GT,CGT}(similar(G.G), similar(G.control_deriv_ops))

function _getindex(
::Val{true},
Ψ::GradVector{num_controls,T},
k::Int
) where {num_controls,T}
N = length(Ψ.state)
L = num_controls
block = (k - 1) ÷ N + 1
local_k = (k - 1) % N + 1
if block <= L
return Ψ.grad_states[block][local_k]
else
return Ψ.state[local_k]
end
end

function Base.eltype(O::GradgenOperator{num_controls,GT,CGT}) where {num_controls,GT,CGT}
return promote_type(eltype(GT), eltype(CGT))
function _getindex(::Val{false}, Ψ::GradVector, k::Int)
error("$(typeof(Ψ)) does not support the vector interface")
end

function Base.copyto!(dest::GradgenOperator, src::GradgenOperator)
copyto!(dest.G, src.G)
copyto!(dest.control_deriv_ops, src.control_deriv_ops)
function Base.getindex(Ψ::T, k::Int) where {T<:GradVector}
return _getindex(Val(supports_vector_interface(T)), Ψ, k)
end


function _setindex!(
::Val{true},
Ψ::GradVector{num_controls,T},
v,
k::Int
) where {num_controls,T}
N = length(Ψ.state)
L = num_controls
block = (k - 1) ÷ N + 1
local_k = (k - 1) % N + 1
if block <= L
Ψ.grad_states[block][local_k] = v
else
Ψ.state[local_k] = v
end
return Ψ
end

function _setindex!(::Val{false}, Ψ::GradVector, v, k::Int)
error("$(typeof(Ψ)) does not support the vector interface")
end

function Base.setindex!(Ψ::T, v, k::Int) where {T<:GradVector}
return _setindex!(Val(supports_vector_interface(T)), Ψ, v, k)
end


function _iterate(::Val{true}, Ψ::GradVector, k)
k > length(Ψ) && return nothing
return (Ψ[k], k + 1)
end

function _iterate(::Val{false}, Ψ::GradVector, k)
error("$(typeof(Ψ)) does not support the vector interface")
end

function Base.iterate(Ψ::T, k = 1) where {T<:GradVector}
return _iterate(Val(supports_vector_interface(T)), Ψ, k)
end


function Base.similar(Ψ::GradVector{num_controls,T}) where {num_controls,T}
state_sim = similar(Ψ.state)
grad_states_sim = [similar(ϕ) for ϕ ∈ Ψ.grad_states]
return GradVector{num_controls,typeof(state_sim)}(state_sim, grad_states_sim)
end

# similar(Ψ, S) calls length(Ψ), which will error if !supports_vector_interface
Base.similar(Ψ::GradVector, ::Type{S}) where {S} = Vector{S}(undef, length(Ψ))

# similar(Ψ, dims) calls eltype(Ψ) but not length/size, so no vector interface needed
Base.similar(Ψ::GradVector, dims::Tuple{Vararg{Int}}) = Array{eltype(Ψ)}(undef, dims)

# These definitions of `similar` exist to make ExponentialUtilities happy, but
# it's not clear at all that `similar` with a custom shape really makes sense
Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int,Int}) where {T} =
Matrix{T}(undef, dims...)

Base.similar(::GradVector, ::Type{T}, dims::Tuple{Int}) where {T} =
Vector{T}(undef, dims[1])


function Base.fill!(Ψ::GradVector, v)
Base.fill!(Ψ.state, v)
Expand All @@ -115,6 +210,139 @@ function Base.fill!(Ψ::GradVector, v)
end


# === Matrix interface for GradgenOperator ===
#
# The following methods are part of the matrix interface and are only
# meaningful when `supports_matrix_interface` is true for both component types.
# Each method delegates to a private `_name(::Val{supports}, ...)` function:
# the Val{true} method contains the implementation, and the Val{false} method
# throws an error.

function _size(
::Val{true},
O::GradgenOperator{num_controls,GT,CGT}
) where {num_controls,GT,CGT}
return (num_controls + 1) .* size(O.G)
end

function _size(::Val{false}, O::GradgenOperator)
error("$(typeof(O)) does not support the matrix interface")
end

function Base.size(O::T) where {T<:GradgenOperator}
return _size(Val(supports_matrix_interface(T)), O)
end


function _size(
::Val{true},
O::GradgenOperator{num_controls,GT,CGT},
dim::Integer
) where {num_controls,GT,CGT}
return (num_controls + 1) * size(O.G, dim)
end

function _size(::Val{false}, O::GradgenOperator, dim::Integer)
error("$(typeof(O)) does not support the matrix interface")
end

function Base.size(O::T, dim::Integer) where {T<:GradgenOperator}
return _size(Val(supports_matrix_interface(T)), O, dim)
end


# As for an `Operator`, we implement `similar` to return a standard `Array`
# because `GradgenOperator` does not `setindex!`, so it's arguably not a
# "mutable array" even if its components are mutable.
# similar(O) and similar(O, S) call size(O), which will error if
# !supports_matrix_interface. The dims-based variants need no guard.
Base.similar(G::GradgenOperator) = Array{eltype(G)}(undef, size(G))

Base.similar(O::GradgenOperator, ::Type{S}) where {S} = Array{S}(undef, size(O))
Base.similar(O::GradgenOperator, dims::Tuple{Vararg{Int}}) = Array{eltype(O)}(undef, dims)
Base.similar(O::GradgenOperator, ::Type{S}, dims::Tuple{Vararg{Int}}) where {S} =
Array{S}(undef, dims)


function Base.eltype(
::Type{GradgenOperator{num_controls,GT,CGT}}
) where {num_controls,GT,CGT}
return promote_type(eltype(GT), eltype(CGT))
end


function _getindex(
::Val{true},
O::GradgenOperator{num_controls,GT,CGT},
row::Int,
col::Int
) where {num_controls,GT,CGT}
T = eltype(O)
N, M = size(O.G)
L = num_controls
block_row = (row - 1) ÷ N + 1
block_col = (col - 1) ÷ M + 1
local_row = (row - 1) % N + 1
local_col = (col - 1) % M + 1
if block_row == block_col
return convert(T, O.G[local_row, local_col])
elseif block_col == L + 1 && block_row <= L
return convert(T, O.control_deriv_ops[block_row][local_row, local_col])
else
return zero(T)
end
end

function _getindex(::Val{false}, O::GradgenOperator, row::Int, col::Int)
error("$(typeof(O)) does not support the matrix interface")
end

function Base.getindex(O::T, row::Int, col::Int) where {T<:GradgenOperator}
return _getindex(Val(supports_matrix_interface(T)), O, row, col)
end


function _length(::Val{true}, O::GradgenOperator)
return prod(size(O))
end

function _length(::Val{false}, O::GradgenOperator)
error("$(typeof(O)) does not support the matrix interface")
end

function Base.length(O::T) where {T<:GradgenOperator}
return _length(Val(supports_matrix_interface(T)), O)
end


function _iterate(::Val{true}, O::GradgenOperator, k)
n = length(O)
k > n && return nothing
n_rows = size(O, 1)
i = (k - 1) % n_rows + 1
j = (k - 1) ÷ n_rows + 1
return (O[i, j], k + 1)
end

function _iterate(::Val{false}, O::GradgenOperator, k)
error("$(typeof(O)) does not support the matrix interface")
end

function Base.iterate(O::T, k = 1) where {T<:GradgenOperator}
return _iterate(Val(supports_matrix_interface(T)), O, k)
end


function Base.eltype(::Type{GradVector{num_controls,T}}) where {num_controls,T}
return eltype(T)
end

function Base.copyto!(dest::GradgenOperator, src::GradgenOperator)
copyto!(dest.G, src.G)
copyto!(dest.control_deriv_ops, src.control_deriv_ops)
end


function Base.zero(Ψ::GradVector{num_controls,T}) where {num_controls,T}
return GradVector{num_controls,T}(zero(Ψ.state), [zero(ϕ) for ϕ ∈ Ψ.grad_states])
end
Expand Down
Loading