Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions ext/QuantumControlFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,10 @@ using LinearAlgebra

import FiniteDifferences
import QuantumControl.Functionals:
_default_chi_via, make_gate_chi, make_automatic_chi, make_automatic_grad_J_a
make_gate_chi, make_automatic_chi, make_automatic_grad_J_a


function make_automatic_chi(
J_T,
trajectories,
::Val{:FiniteDifferences};
via=_default_chi_via(trajectories)
)
function make_automatic_chi(J_T, trajectories, ::Val{:FiniteDifferences}; via=:states)

# TODO: Benchmark if χ should be closure, see QuantumControlZygoteExt.jl

Expand Down
25 changes: 16 additions & 9 deletions ext/QuantumControlZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,10 @@ using LinearAlgebra

import Zygote
import QuantumControl.Functionals:
_default_chi_via, make_gate_chi, make_automatic_chi, make_automatic_grad_J_a
make_gate_chi, make_automatic_chi, make_automatic_grad_J_a


function make_automatic_chi(
J_T,
trajectories,
::Val{:Zygote};
via=_default_chi_via(trajectories)
)
function make_automatic_chi(J_T, trajectories, ::Val{:Zygote}; via=:states)

# TODO: At some point, for a large system, we could benchmark if there is
# any benefit to making χ a closure and using LinearAlgebra.axpby! to
Expand All @@ -26,7 +21,14 @@ function make_automatic_chi(
χ = Vector{eltype(Ψ)}(undef, length(Ψ))
∇J = Zygote.gradient(_J_T, Ψ...)
for (k, ∇Jₖ) ∈ enumerate(∇J)
χ[k] = 0.5 * ∇Jₖ # ½ corrects for gradient vs Wirtinger deriv
if isnothing(∇Jₖ)
# Functional does not depend on Ψₖ. That probably means a buggy
# J_T, but who knows: maybe there are situations where that
# makes sense. It would be extremely noisy to warn here.
χ[k] = zero(χ[k])
else
χ[k] = 0.5 * ∇Jₖ # ½ corrects for gradient vs Wirtinger deriv
end
# axpby!(0.5, ∇Jₖ, false, χ[k])
end
return χ
Expand All @@ -43,7 +45,12 @@ function make_automatic_chi(
χ = Vector{eltype(Ψ)}(undef, length(Ψ))
∇J = Zygote.gradient(_J_T, τ...)
for (k, traj) ∈ enumerate(trajectories)
∂J╱∂τ̄ₖ = 0.5 * ∇J[k] # ½ corrects for gradient vs Wirtinger deriv
if isnothing(∇J[k])
# Functional does not depend on τₖ
∂J╱∂τ̄ₖ = zero(ComplexF64)
else
∂J╱∂τ̄ₖ = 0.5 * ∇J[k] # ½ corrects for gradient vs Wirtinger deriv
end
χ[k] = ∂J╱∂τ̄ₖ * traj.target_state
# axpby!(∂J╱∂τ̄ₖ, traj.target_state, false, χ[k])
end
Expand Down
93 changes: 70 additions & 23 deletions src/functionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,21 @@ export make_grad_J_a, make_chi
using LinearAlgebra: axpy!, dot


# default for `via` argument of `make_chi`
function _default_chi_via(trajectories)
if any(isnothing(traj.target_state) for traj in trajectories)
return :states
else
return :tau
function _check_chi(chi; states, trajectories, tau, via)
try
if via == :tau
chi_states = chi(states, trajectories; tau)
else
chi_states = chi(states, trajectories)
end
if typeof(chi_states) ≠ typeof(states)
msg = "`chi` must return a vector of states"
error(msg)
end
catch exception
msg = "The chi generated by `make_chi` does not have the required interface"
@error msg exception
error("Cannot make chi")
end
end

Expand Down Expand Up @@ -86,15 +95,25 @@ chi = make_chi(
trajectories;
mode=:any,
automatic=:default,
via=(any(isnothing(t.target_state) for t in trajectories) ? :states : :tau),
via=:automatic, # one of :automatic, :tau, :states
)
```

creates a function `chi(Ψ, trajectories; τ)` that returns
a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩`` is the
k'th element of `Ψ`. These are the states used as the boundary condition for
the backward propagation propagation in Krotov's method and GRAPE. Each
``|χₖ⟩`` is defined as a matrix calculus
creates a function `chi(Ψ, trajectories)` or `chi(Ψ, trajectories; tau)` that
returns a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩``
is the k'th element of `Ψ`. These are the states used as the boundary condition
for the backward propagation propagation in Krotov's method and GRAPE.

The resulting `chi` function takes the keyword argument `tau`
if and only if `via=:tau` or `via=:automatic` (default) if the following
conditions are met:

* All `trajectories` have a defined `target_state` component (not `nothing`)
* `J_T` takes `tau` as a keyword argument (determined via introspection)

Both of these conditions are _requirements_ for `via=:tau`.

Each ``|χₖ⟩`` is defined as a matrix calculus
[Wirtinger derivative](https://www.ekinakyurek.me/complex-derivatives-wirtinger/),

```math
Expand Down Expand Up @@ -193,25 +212,53 @@ and the definition of the Zygote gradient with respect to a complex scalar,
gradients). Always test automatic derivatives against finite differences
and/or other automatic differentiation frameworks.
"""
function make_chi(
J_T,
trajectories;
mode=:any,
automatic=:default,
via=_default_chi_via(trajectories),
)
function make_chi(J_T, trajectories; mode=:any, automatic=:default, via=:automatic,)
states = [traj.initial_state for traj in trajectories]
tau = [zero(ComplexF64) for _ in states]
J_T_takes_tau = hasmethod(J_T, Tuple{typeof(states),typeof(trajectories)}, (:tau,))
has_target_states = all((traj.target_state ≢ nothing) for traj in trajectories)
if (via == :tau) && !J_T_takes_tau
msg = "Called `make_chi` with `via=:tau`, but given J_T does not take `tau` keyword argument"
error(msg)
end
if (via == :tau) && !has_target_states
msg = "Called `make_chi` with `via=:tau`, but not all `trajectories` define a `target_state`"
error(msg)
end
if via == :automatic
via = :states
if J_T_takes_tau && has_target_states
via = :tau
end
end
chi = nothing
try
if via == :tau
J_T_val = J_T(states, trajectories; tau)
else
J_T_val = J_T(states, trajectories)
end
if !(J_T_val isa Float64)
msg = "J_T passed to `make_chi` must return a Float64, not $(typeof(J_T_val))"
error(msg)
end
catch exception
msg = "The J_T passed to `make_chi` does not have the required interface"
@error msg exception
error("Cannot make chi")
end
if mode == :any
try
chi = make_analytic_chi(J_T, trajectories)
@debug "make_chi for J_T=$(J_T) -> analytic"
# TODO: call chi to compile it and ensure required properties
_check_chi(chi; states, trajectories, tau, via)
return chi
catch exception
if exception isa MethodError
@info "make_chi for J_T=$(J_T): fallback to mode=:automatic"
try
chi = make_automatic_chi(J_T, trajectories, automatic; via)
# TODO: call chi to compile it and ensure required properties
_check_chi(chi; states, trajectories, tau, via)
return chi
catch exception
if exception isa MethodError
Expand All @@ -228,7 +275,7 @@ function make_chi(
elseif mode == :analytic
try
chi = make_analytic_chi(J_T, trajectories)
# TODO: call chi to compile it and ensure required properties
_check_chi(chi; states, trajectories, tau, via)
return chi
catch exception
if exception isa MethodError
Expand All @@ -241,7 +288,7 @@ function make_chi(
elseif mode == :automatic
try
chi = make_automatic_chi(J_T, trajectories, automatic; via)
# TODO: call chi to compile it and ensure required properties
_check_chi(chi; states, trajectories, tau, via)
return chi
catch exception
if exception isa MethodError
Expand Down
4 changes: 2 additions & 2 deletions test/test_functionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,13 @@ end
throw(DomainError("XXX"))
end

@test_throws DomainError begin
@test_throws Exception begin
IOCapture.capture() do
make_chi(J_T_xxx, trajectories)
end
end

@test_throws DomainError begin
@test_throws Exception begin
IOCapture.capture() do
make_chi(J_T_xxx, trajectories; mode=:automatic)
end
Expand Down