Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 157 additions & 35 deletions src/linearcombinations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ GaussianLinearCombination with 2 terms for 1 mode.
[1] 0.6 * GaussianState
[2] 0.8 * GaussianState
"""
const GaussianObject = Union{GaussianState,GaussianUnitary,GaussianChannel}

_is_gaussian_family(::Type{T}) where {T} = T <: GaussianState || T <: GaussianUnitary || T <: GaussianChannel
_gaussian_family(::Type{T}) where {T<:GaussianState} = GaussianState
_gaussian_family(::Type{T}) where {T<:GaussianUnitary} = GaussianUnitary
_gaussian_family(::Type{T}) where {T<:GaussianChannel} = GaussianChannel

function _gaussian_coeff_type(x::GaussianState)
return float(real(promote_type(eltype(x.mean), eltype(x.covar))))
end
function _gaussian_coeff_type(x::GaussianUnitary)
return float(real(promote_type(eltype(x.disp), eltype(x.symplectic))))
end
function _gaussian_coeff_type(x::GaussianChannel)
return float(real(promote_type(eltype(x.disp), eltype(x.transform), eltype(x.noise))))
end

mutable struct GaussianLinearCombination{B<:SymplecticBasis,C,S}
basis::B
coeffs::Vector{C}
Expand All @@ -57,14 +74,16 @@ mutable struct GaussianLinearCombination{B<:SymplecticBasis,C,S}
function GaussianLinearCombination(basis::B, coeffs::Vector{C}, states::Vector{S}) where {B<:SymplecticBasis,C,S}
length(coeffs) == length(states) || throw(DimensionMismatch("Number of coefficients ($(length(coeffs))) must match number of states ($(length(states)))"))
isempty(states) && throw(ArgumentError("Cannot create an empty linear combination"))
_is_gaussian_family(S) || throw(ArgumentError("Linear combinations only support GaussianState, GaussianUnitary, or GaussianChannel elements"))
ħ = first(states).ħ
@inbounds for (i, state) in enumerate(states)
state isa GaussianState || throw(ArgumentError("Element $i is not a GaussianState: got $(typeof(state))"))
if state.basis != basis
throw(ArgumentError("State $i has incompatible basis: expected $(typeof(basis))($(basis.nmodes)), got $(typeof(state.basis))($(state.basis.nmodes))"))
family = _gaussian_family(S)
@inbounds for (i, component) in enumerate(states)
component isa family || throw(ArgumentError("Element $i is not a $(family): got $(typeof(component))"))
if component.basis != basis
throw(ArgumentError("Element $i has incompatible basis: expected $(typeof(basis))($(basis.nmodes)), got $(typeof(component.basis))($(component.basis.nmodes))"))
end
if state.ħ != ħ
throw(ArgumentError("State $i has different ħ: expected $ħ, got $(state.ħ)"))
if component.ħ != ħ
throw(ArgumentError("Element $i has different ħ: expected $ħ, got $(component.ħ)"))
end
end
return new{B,C,S}(basis, coeffs, states, ħ)
Expand All @@ -75,8 +94,8 @@ end

Create a linear combination containing a single Gaussian state with coefficient 1.0.
"""
function GaussianLinearCombination(state::GaussianState{B,M,V}) where {B,M,V}
coeff_type = float(real(promote_type(eltype(M), eltype(V))))
function GaussianLinearCombination(state::S) where {S<:GaussianObject}
coeff_type = _gaussian_coeff_type(state)
return GaussianLinearCombination(state.basis, [one(coeff_type)], [state])
end
"""
Expand All @@ -88,28 +107,30 @@ function GaussianLinearCombination(pairs::Vector{<:Tuple})
isempty(pairs) && throw(ArgumentError("Cannot create an empty linear combination"))
coeffs = [convert(Number, p[1]) for p in pairs]
states = [p[2] for p in pairs]
first_state_type = typeof(first(states))
_is_gaussian_family(first_state_type) || throw(ArgumentError("Element 1: second element must be a GaussianState, GaussianUnitary, or GaussianChannel"))
@inbounds for (i, state) in enumerate(states)
state isa GaussianState || throw(ArgumentError("Element $i: second element must be a GaussianState"))
state isa first_state_type || throw(ArgumentError("Element $i: second element type does not match element 1"))
end
basis = first(states).basis
return GaussianLinearCombination(basis, coeffs, states)
end
"""
GaussianLinearCombination(coeffs::Vector{<:Number}, states::Vector{<:GaussianState})
GaussianLinearCombination(coeffs::Vector{<:Number}, states::Vector{<:GaussianObject})

Create a linear combination from separate vectors of coefficients and states.
Create a linear combination from separate vectors of coefficients and Gaussian components.
"""
function GaussianLinearCombination(coeffs::Vector{<:Number}, states::Vector{<:GaussianState})
function GaussianLinearCombination(coeffs::Vector{<:Number}, states::Vector{<:GaussianObject})
isempty(states) && throw(ArgumentError("Cannot create an empty linear combination"))
basis = first(states).basis
return GaussianLinearCombination(basis, coeffs, states)
end
"""
GaussianLinearCombination(pairs::Pair{<:Number,<:GaussianState}...)
GaussianLinearCombination(pairs::Pair{<:Number,<:GaussianObject}...)

Create a linear combination from coefficient => state pairs.
Create a linear combination from coefficient => Gaussian component pairs.
"""
function GaussianLinearCombination(pairs::Pair{<:Number,<:GaussianState}...)
function GaussianLinearCombination(pairs::Pair{<:Number,<:GaussianObject}...)
isempty(pairs) && throw(ArgumentError("Cannot create an empty linear combination"))
coeffs = [p.first for p in pairs]
states = [p.second for p in pairs]
Expand All @@ -127,6 +148,7 @@ Add two linear combinations of Gaussian states. Both must have the same symplect
function Base.:+(lc1::GaussianLinearCombination{B}, lc2::GaussianLinearCombination{B}) where {B<:SymplecticBasis}
lc1.basis == lc2.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
lc1.ħ == lc2.ħ || throw(ArgumentError(HBAR_ERROR))
_gaussian_family(typeof(first(lc1.states))) == _gaussian_family(typeof(first(lc2.states))) || throw(ArgumentError("Cannot add linear combinations with different Gaussian component families"))
coeffs = vcat(lc1.coeffs, lc2.coeffs)
states = vcat(lc1.states, lc2.states)
return GaussianLinearCombination(lc1.basis, coeffs, states)
Expand Down Expand Up @@ -172,16 +194,16 @@ Base.:*(lc::GaussianLinearCombination, α::Number) = α * lc

Multiply a Gaussian state by a scalar to create a linear combination.
"""
function Base.:*(α::Number, state::GaussianState)
coeff_type = promote_type(typeof(α), eltype(state.mean), eltype(state.covar))
function Base.:*(α::Number, state::S) where {S<:GaussianObject}
coeff_type = promote_type(typeof(α), _gaussian_coeff_type(state))
return GaussianLinearCombination(state.basis, [convert(coeff_type, α)], [state])
end
"""
*(state::GaussianState, α::Number)

Multiply a Gaussian state by a scalar to create a linear combination.
"""
Base.:*(state::GaussianState, α::Number) = α * state
Base.:*(state::S, α::Number) where {S<:GaussianObject} = α * state
"""
+(state1::GaussianState, state2::GaussianState)

Expand All @@ -194,32 +216,82 @@ function Base.:+(state1::GaussianState, state2::GaussianState)
eltype(state2.mean), eltype(state2.covar))
return GaussianLinearCombination(state1.basis, [one(coeff_type), one(coeff_type)], [state1, state2])
end

function Base.:+(op1::GaussianUnitary, op2::GaussianUnitary)
op1.basis == op2.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
op1.ħ == op2.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(op1), _gaussian_coeff_type(op2))
return GaussianLinearCombination(op1.basis, [one(coeff_type), one(coeff_type)], [op1, op2])
end

function Base.:+(ch1::GaussianChannel, ch2::GaussianChannel)
ch1.basis == ch2.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
ch1.ħ == ch2.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(ch1), _gaussian_coeff_type(ch2))
return GaussianLinearCombination(ch1.basis, [one(coeff_type), one(coeff_type)], [ch1, ch2])
end
"""
+(state::GaussianState, lc::GaussianLinearCombination)

Add a Gaussian state to a linear combination.
"""
function Base.:+(state::GaussianState, lc::GaussianLinearCombination)
function Base.:+(state::GaussianState, lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState})
state.basis == lc.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
state.ħ == lc.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(eltype(state.mean), eltype(state.covar), eltype(lc.coeffs))
new_coeffs = vcat(one(coeff_type), convert(Vector{coeff_type}, lc.coeffs))
new_states = vcat(state, lc.states)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end

function Base.:+(op::GaussianUnitary, lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianUnitary})
op.basis == lc.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
op.ħ == lc.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(op), eltype(lc.coeffs))
new_coeffs = vcat(one(coeff_type), convert(Vector{coeff_type}, lc.coeffs))
new_states = vcat(op, lc.states)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end

function Base.:+(ch::GaussianChannel, lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianChannel})
ch.basis == lc.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
ch.ħ == lc.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(ch), eltype(lc.coeffs))
new_coeffs = vcat(one(coeff_type), convert(Vector{coeff_type}, lc.coeffs))
new_states = vcat(ch, lc.states)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end
"""
+(lc::GaussianLinearCombination, state::GaussianState)

Add a linear combination to a Gaussian state.
"""
function Base.:+(lc::GaussianLinearCombination, state::GaussianState)
function Base.:+(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState}, state::GaussianState)
lc.basis == state.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
lc.ħ == state.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(eltype(state.mean), eltype(state.covar), eltype(lc.coeffs))
new_coeffs = vcat(convert(Vector{coeff_type}, lc.coeffs), one(coeff_type))
new_states = vcat(lc.states, state)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end

function Base.:+(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianUnitary}, op::GaussianUnitary)
lc.basis == op.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
lc.ħ == op.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(op), eltype(lc.coeffs))
new_coeffs = vcat(convert(Vector{coeff_type}, lc.coeffs), one(coeff_type))
new_states = vcat(lc.states, op)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end

function Base.:+(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianChannel}, ch::GaussianChannel)
lc.basis == ch.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
lc.ħ == ch.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(ch), eltype(lc.coeffs))
new_coeffs = vcat(convert(Vector{coeff_type}, lc.coeffs), one(coeff_type))
new_states = vcat(lc.states, ch)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end
"""
-(state1::GaussianState, state2::GaussianState)

Expand All @@ -232,35 +304,79 @@ function Base.:-(state1::GaussianState, state2::GaussianState)
eltype(state2.mean), eltype(state2.covar))
return GaussianLinearCombination(state1.basis, [one(coeff_type), -one(coeff_type)], [state1, state2])
end

function Base.:-(op1::GaussianUnitary, op2::GaussianUnitary)
op1.basis == op2.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
op1.ħ == op2.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(op1), _gaussian_coeff_type(op2))
return GaussianLinearCombination(op1.basis, [one(coeff_type), -one(coeff_type)], [op1, op2])
end

function Base.:-(ch1::GaussianChannel, ch2::GaussianChannel)
ch1.basis == ch2.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
ch1.ħ == ch2.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(ch1), _gaussian_coeff_type(ch2))
return GaussianLinearCombination(ch1.basis, [one(coeff_type), -one(coeff_type)], [ch1, ch2])
end
"""
-(state::GaussianState, lc::GaussianLinearCombination)

Subtract a linear combination from a Gaussian state.
"""
function Base.:-(state::GaussianState, lc::GaussianLinearCombination)
function Base.:-(state::GaussianState, lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState})
state.basis == lc.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
state.ħ == lc.ħ || throw(ArgumentError(HBAR_ERROR))
return state + (-1) * lc
end

function Base.:-(op::GaussianUnitary, lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianUnitary})
op.basis == lc.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
op.ħ == lc.ħ || throw(ArgumentError(HBAR_ERROR))
return op + (-1) * lc
end

function Base.:-(ch::GaussianChannel, lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianChannel})
ch.basis == lc.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
ch.ħ == lc.ħ || throw(ArgumentError(HBAR_ERROR))
return ch + (-1) * lc
end
"""
-(lc::GaussianLinearCombination, state::GaussianState)

Subtract a Gaussian state from a linear combination.
"""
function Base.:-(lc::GaussianLinearCombination, state::GaussianState)
function Base.:-(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState}, state::GaussianState)
lc.basis == state.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
lc.ħ == state.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(eltype(state.mean), eltype(state.covar), eltype(lc.coeffs))
new_coeffs = vcat(convert(Vector{coeff_type}, lc.coeffs), -one(coeff_type))
new_states = vcat(lc.states, state)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end

function Base.:-(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianUnitary}, op::GaussianUnitary)
lc.basis == op.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
lc.ħ == op.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(op), eltype(lc.coeffs))
new_coeffs = vcat(convert(Vector{coeff_type}, lc.coeffs), -one(coeff_type))
new_states = vcat(lc.states, op)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end

function Base.:-(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianChannel}, ch::GaussianChannel)
lc.basis == ch.basis || throw(ArgumentError(SYMPLECTIC_ERROR))
lc.ħ == ch.ħ || throw(ArgumentError(HBAR_ERROR))
coeff_type = promote_type(_gaussian_coeff_type(ch), eltype(lc.coeffs))
new_coeffs = vcat(convert(Vector{coeff_type}, lc.coeffs), -one(coeff_type))
new_states = vcat(lc.states, ch)
return GaussianLinearCombination(lc.basis, new_coeffs, new_states)
end
"""
-(state::GaussianState)

Negate a Gaussian state to create a linear combination with coefficient -1.
"""
Base.:-(state::GaussianState) = (-1) * state
Base.:-(state::S) where {S<:GaussianObject} = (-1) * state

"""
normalize!(lc::GaussianLinearCombination)
Expand Down Expand Up @@ -317,10 +433,13 @@ function simplify!(lc::GaussianLinearCombination; atol::Real=1e-14)
end
keep_mask = abs.(lc.coeffs) .> atol
if !any(keep_mask)
vac = vacuumstate(lc.basis, ħ = lc.ħ)
coeff_type = eltype(lc.coeffs)
lc.coeffs = [coeff_type(atol)]
lc.states = [vac]
lc.coeffs = [coeff_type(atol)]
if _gaussian_family(typeof(first(lc.states))) === GaussianState
lc.states = [vacuumstate(lc.basis, ħ = lc.ħ)]
else
lc.states = [first(lc.states)]
end
return lc
end
coeffs = lc.coeffs[keep_mask]
Expand All @@ -342,10 +461,13 @@ function simplify!(lc::GaussianLinearCombination; atol::Real=1e-14)
resize!(combined_coeffs, n_unique)
final_mask = abs.(combined_coeffs) .> atol
if !any(final_mask)
vac = vacuumstate(lc.basis, ħ = lc.ħ)
coeff_type = eltype(combined_coeffs)
lc.coeffs = [coeff_type(atol)]
lc.states = [vac]
if _gaussian_family(typeof(first(unique_states))) === GaussianState
lc.states = [vacuumstate(lc.basis, ħ = lc.ħ)]
else
lc.states = [first(unique_states)]
end
else
lc.coeffs = combined_coeffs[final_mask]
lc.states = unique_states[final_mask]
Expand All @@ -366,8 +488,8 @@ function Base.show(io::IO, mime::MIME"text/plain", lc::GaussianLinearCombination
println(io, " ħ = $(lc.ħ)")
max_display = min(length(lc), 5)
@inbounds for i in 1:max_display
coeff, state = lc[i]
println(io, " [$i] $(coeff) * GaussianState")
coeff, component = lc[i]
println(io, " [$i] $(coeff) * $(nameof(typeof(component)))")
end
if length(lc) > max_display
println(io, " ⋮ ($(length(lc) - max_display) more terms)")
Expand Down Expand Up @@ -624,7 +746,7 @@ end
Compute Wigner function of a linear combination including quantum interference.
`W(x) = Σᵢ |cᵢ|² Wᵢ(x) + 2 Σᵢ<ⱼ Re(cᵢ*cⱼ W_cross(ψᵢ,ψⱼ,x))`
"""
function wigner(lc::GaussianLinearCombination, x::AbstractVector)
function wigner(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState}, x::AbstractVector)
length(x) == length(lc.states[1].mean) || throw(ArgumentError(WIGNER_ERROR))
result = 0.0
@inbounds for i in 1:length(lc)
Expand Down Expand Up @@ -679,7 +801,7 @@ end

Compute Wigner characteristic function of a linear combination including interference.
"""
function wignerchar(lc::GaussianLinearCombination, xi::AbstractVector)
function wignerchar(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState}, xi::AbstractVector)
length(xi) == length(lc.states[1].mean) || throw(ArgumentError(WIGNER_ERROR))
result = 0.0 + 0.0im
@inbounds for i in 1:length(lc)
Expand All @@ -694,14 +816,14 @@ function wignerchar(lc::GaussianLinearCombination, xi::AbstractVector)
return result
end

function purity(lc::GaussianLinearCombination)
function purity(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState})
# A GaussianLinearCombination represents a pure superposition state
# |ψ⟩ = Σᵢ cᵢ|ψᵢ⟩, so purity = Tr(ρ²) = 1
return 1.0
end

function entropy_vn(lc::GaussianLinearCombination)
function entropy_vn(lc::GaussianLinearCombination{<:Any,<:Any,<:GaussianState})
# A GaussianLinearCombination represents a pure superposition state
# |ψ⟩ = Σᵢ cᵢ|ψᵢ⟩, so S(ρ) = -Tr(ρ log ρ) = 0
return 0.0
end
end
Loading
Loading