Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 117 additions & 4 deletions qfi_opt/examples/calculate_qfi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import numpy as np
from scipy.linalg import solve_sylvester

from qfi_opt import spin_models

Expand All @@ -16,10 +17,83 @@ def compute_eigendecomposition(rho: np.ndarray):
return eigvals, eigvecs


def compute_QFI(eigvals: np.ndarray, eigvecs: np.ndarray, G: np.ndarray, A:
np.ndarray= np.empty(0), dA: np.ndarray= np.empty(0), d2A: np.ndarray = np.empty(0),
grad: np.ndarray = np.empty(0), tol: float = 1e-8, etol_scale: float =
10) -> float:
def compute_QFI_simpler_api(
eigvals: np.ndarray, eigvecs: np.ndarray, A: np.ndarray, params: np.ndarray, grad, get_jacobian, obj_params, tol: float = 1e-8, etol_scale: float = 10
) -> float:
# Note: The eigenvectors must be rows of eigvecs
num_vals = len(eigvals)
num_params = len(params)

G = obj_params["G"]

# There should never be negative eigenvalues, so their magnitude gives an
# empirical estimate of the numerical accuracy of the eigendecomposition.
# We discard any QFI terms denominators within an order of magnitude of
# this value.
tol = max(tol, -etol_scale * np.min(eigvals))

# Compute QFI and grad
running_sum = 0

if grad.size > 0:

dA = get_jacobian(params, obj_params["N"], dissipation_rates=obj_params["dissipation"])
dA = np.transpose(dA, (2, 0, 1))

grad[:] = np.zeros(num_params)
lambda_grads = np.zeros((num_params, num_vals))
psi_grads = np.zeros((num_params, num_vals, num_vals), dtype="cdouble")

for k in range(num_params):
# compute gradients of each eigenvalue
lambda_grad_k, psi_grad_k = get_matrix_grads_sylvester(dA[k], eigvals, eigvecs, tol)
lambda_grads[k] = lambda_grad_k
psi_grads[k] = psi_grad_k

for i in range(num_vals):
for j in range(i + 1, num_vals):
denom = eigvals[i] + eigvals[j]
diff = eigvals[i] - eigvals[j]
if not np.isclose(denom, 0, atol=tol, rtol=tol) and not np.isclose(diff, 0, atol=tol, rtol=tol):
numer = diff**2
term = eigvecs[i].conj() @ G @ eigvecs[j]
quotient = numer / denom
squared_modulus = np.absolute(term) ** 2
running_sum += quotient * squared_modulus
if grad.size > 0:
for k in range(num_params):
# fill in gradient
grad[k] += kth_partial_derivative(
quotient,
squared_modulus,
eigvals[i],
eigvals[j],
lambda_grads[k, i],
lambda_grads[k, j],
eigvecs[i],
eigvecs[j],
psi_grads[k, i],
psi_grads[k, j],
G,
)

if grad.size > 0:
return 4 * running_sum, 4 * grad
else:
return 4 * running_sum, []


def compute_QFI(
eigvals: np.ndarray,
eigvecs: np.ndarray,
G: np.ndarray,
A: np.ndarray = np.empty(0),
dA: np.ndarray = np.empty(0),
d2A: np.ndarray = np.empty(0),
grad: np.ndarray = np.empty(0),
tol: float = 1e-8,
etol_scale: float = 10,
) -> float:
# Note: The eigenvectors must be rows of eigvecs
num_vals = len(eigvals)
num_params = dA.shape[0]
Expand Down Expand Up @@ -87,6 +161,45 @@ def get_matrix_grads_lazy(A, dA, eigvals, eigvecs):
return lambda_grads, psi_grads


def get_matrix_grads_sylvester(dA, eigvals, eigvecs, tol):

dim = eigvecs.shape[0]
lambda_grads = np.zeros(dim, dtype="cdouble")
psi_grads = np.zeros((dim, dim), dtype="cdouble")

# force Hermitianness:
dA = (dA + dA.conj().T) / 2.0

# group the sorted eigvals by tolerance, intended to help stability of eigenvector derivatives:
current_ind = 0
for ind1 in range(dim):
if current_ind == ind1:
for ind2 in range(ind1 + 1, dim):
if not np.isclose(eigvals[ind2], eigvals[ind1], atol=tol, rtol=tol):
break # the for loop over ind2
# we just broke the for loop, so:
current_ind = ind2

# do a sylvester solve:
group_set = np.arange(ind1, ind2)
# special case when ind2=dim (must be something smarter)
if group_set.size == 0:
group_set = [ind2]
not_in_group_set = np.setdiff1d(np.arange(dim), group_set)
A_group = np.diag(eigvals[group_set])
A_not_in_group = np.diag(eigvals[not_in_group_set])
rotation = eigvecs[group_set].conj() @ dA @ eigvecs[not_in_group_set].T
sol = solve_sylvester(A_group, -1.0 * A_not_in_group, rotation)
psi_grads[group_set] = sol @ eigvecs[not_in_group_set].conj()

# average eigenvalue:
multiplicity = len(group_set)
dLambda = (1.0 / multiplicity) * np.trace(eigvecs[group_set].conj() @ dA @ eigvecs[group_set].T)
lambda_grads[group_set] = np.ones(multiplicity) * dLambda

return np.real(lambda_grads), psi_grads.conj()


def get_matrix_grads(A, dA, d2A, eigvals, eigvecs, tol):
num_vals = len(eigvals)
num_params = dA.shape[0]
Expand Down
56 changes: 56 additions & 0 deletions qfi_opt/examples/matt_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np

import qfi_opt
from qfi_opt import spin_models
from qfi_opt.examples.calculate_qfi import compute_eigendecomposition, compute_QFI_simpler_api

np.set_printoptions(precision=4, linewidth=200)


def sim_wrapper_diffrax(x, qfi_grad, obj, obj_params, get_jacobian):
rho = obj(x, obj_params["N"], dissipation_rates=obj_params["dissipation"])

# force Hermitianness:
rho = (rho + rho.conj().T) / 2.0
# Compute eigendecomposition
vals, vecs = compute_eigendecomposition(rho)

qfi, new_grad = compute_QFI_simpler_api(vals, vecs, rho, x, qfi_grad, get_jacobian, obj_params)
print(x, qfi, new_grad, flush=True)

try:
if qfi_grad.size > 0:
qfi_grad[:] = -1.0 * new_grad
except:
qfi_grad[:] = []

return -1.0 * qfi # , -qfi_grad # negative because we are maximizing


if __name__ == "__main__":

N = 4
G = spin_models.collective_op(spin_models.PAULI_Z, N) / (2 * N)

obj_params = {}
obj_params["N"] = N
obj_params["dissipation"] = 1.0
obj_params["G"] = G

seed = 88
np.random.seed(seed)

num_params = 5

lb = np.zeros(num_params)
ub = np.ones(num_params)

x0 = np.random.uniform(lb, ub, num_params)
model = "simulate_TAT"

obj = getattr(spin_models, model)

grad = np.zeros(num_params)
get_jacobian = spin_models.get_jacobian_func(obj)
out = sim_wrapper_diffrax(x0, grad, obj, obj_params, get_jacobian)
print(out)