import warnings
import numpy as np
import psutil
from typing import Sequence, Callable, Literal
from ..sensor_utils import TimeFunc
from ..exceptions import RydiquleError
[docs]
def cyrk_solve(eoms_base: np.ndarray, const_base: np.ndarray,
               eom_time_r: np.ndarray, const_r: np.ndarray,
               eom_time_i: np.ndarray, const_i: np.ndarray,
               time_inputs: Sequence[TimeFunc],
               t_eval: np.ndarray, init_cond: np.ndarray,
               eqns: Literal["orig", "flat"] = "orig",
               **kwargs
               ) -> np.ndarray:
    """
    Solve a set of Optical Bloch Equations (OBEs) with rydiqule's time solving convention
    using CyRK's `pysolve_ivp`.

    Uses matrix components of the equations of motion provided by the methods of a :meth:`~.Sensor`.
    Designed to be used as a wrapped function within :func:`~.timesolvers.solve_time`.

    Builds and solves equations of motion according to rydiqule's time solving conventions.
    Sets up and solves dx/dt = A(t)x + b(t)

    Unless overridden through `kwargs`, the solver is called with `method="RK45"`, `rtol=1e-6`
    and a `max_ram_MB` of one tenth of the currently available system memory (in MB),
    but never less than 2000.
    For larger solve systems, `max_ram_MB` may need to be increased further.

    Args
    ----
    eoms_base: numpy.ndarray
        The matrix of shape `(*l,n,n)` representing the non time-varying portion of the matrix A
        in the equations of motion.
    const_base: numpy.ndarray
        The array of shape `(*l, n)` representing the non time-varying portion of the vector b in the
        equations of motion.
    eom_time_r: numpy.ndarray
        The matrix of shape `(n_t, *l, n, n)` representing the real time-varying portion of the matrix A,
        where n_t is the length of `time_inputs`.
        The ith slice along the first axis should be multiplied by the real part
        of the ith entry in `time_inputs`.
    const_r: numpy.ndarray
        The matrix of shape `(n_t, *l, n)` representing the real time-varying portion of the vector b,
        where n_t is the length of `time_inputs`.
        The ith slice along the first axis should be multiplied by the real part
        of the ith entry in `time_inputs`.
    eom_time_i: numpy.ndarray
        The matrix of shape `(n_t, *l, n, n)` representing the imaginary time-varying portion of the matrix A,
        where n_t is the length of `time_inputs`.
        The ith slice along the first axis should be multiplied by the imaginary part
        of the ith entry in `time_inputs`.
    const_i: numpy.ndarray
        The matrix of shape `(n_t, *l, n)` representing the imaginary time-varying portion of the vector b,
        where n_t is the length of `time_inputs`.
        The ith slice along the first axis should be multiplied by the imaginary part
        of the ith entry in `time_inputs`.
    time_inputs: list(callable)
        List of callable functions of length `n_t`.
        The functions should take a single floating point
        as an input representing the time in microseconds,
        and return a real or complex floating point value representing an
        electric field in V/m at that time.
        Return type of each function must be the same for all inputs t.
        Every function is evaluated once at `t=0.0` to determine that return type;
        functions that are not already numba-jitted are compiled here.
    t_eval: numpy.ndarray
        Array of times to sample the integration at.
        This array must have dtype of float64.
        Its first and last entries are used as the integration span.
    init_cond: numpy.ndarray
        Matrix of shape `(*l, n)` representing the initial state of the system.
    eqns: {"orig", "flat"}
        Method used to generate the derivative equations.
        Options are orig (which uses a numpy reshaping approach)
        and flat (which uses flat array indexing).
    **kwargs: dict:
        Additional keyword arguments passed to `pysolve_ivp`.

    Returns
    -------
    numpy.ndarray
        The matrix solution of shape `(*l, n, n_pts)`
        representing the density matrix of the system at each time t,
        where `n_pts` is the length of the last axis of the solver output.

    Raises
    ------
    RydiquleError
        If `numba` or `CyRK` cannot be imported.
    RydiquleError
        If `eqns` is not one of `"orig"` or `"flat"`.
    RydiquleError
        If the solver reports an unsuccessful integration.

    Notes
    -----
    NOTE(review): earlier documentation stated that a RydiquleError is raised when the
    system exceeds a CyRK backend limit of 65535 equations (`y_size` being an unsigned short).
    No such check is performed in this function - confirm whether the limit still applies.
    """
    try:
        import numba as nb
        from CyRK import pysolve_ivp
    except ImportError as e:
        raise RydiquleError('CyRK backend not installed') from e

    try:
        fns = _eqnsGen[eqns]
    except KeyError as err:
        raise RydiquleError("\'eqns\' must be one of \'orig\' or \'flat\'") from err

    compiled = []
    for f in time_inputs:
        # every function is evaluated once (jitted or not) to learn its return type;
        # numpy complex scalars such as complex64 are not subclasses of `complex`,
        # so they are checked for explicitly
        returns_complex = isinstance(f(0.0), (complex, np.complexfloating))
        if nb.extending.is_jitted(f):
            # already compiled by the caller, use as provided
            compiled.append(f)
        else:
            signature = "c16(f8)" if returns_complex else "f8(f8)"
            compiled.append(nb.njit(signature, cache=True)(f))
    time_inputs_compiled = tuple(compiled)

    with warnings.catch_warnings():
        # ignore first-class function warning
        warnings.simplefilter("ignore", category=nb.NumbaExperimentalFeatureWarning)
        equations = fns(eoms_base, const_base,
                        eom_time_r, const_r,
                        eom_time_i, const_i,
                        time_inputs_compiled
                        )

    # enforce default arguments consistent with scipy solver
    method = kwargs.pop("method", "RK45")
    rtol = kwargs.pop("rtol", 1e-6)
    if "max_ram_MB" in kwargs:
        max_ram_MB = kwargs.pop("max_ram_MB")
    else:
        # only query system memory when the caller did not choose a limit
        available_MB = psutil.virtual_memory().available/(1024**2)
        max_ram_MB = max(available_MB/10, 2_000)

    result = pysolve_ivp(equations, (t_eval[0], t_eval[-1]),
                         init_cond.ravel(),
                         t_eval=t_eval,
                         method=method, rtol=rtol,
                         max_ram_MB=max_ram_MB,
                         pass_dy_as_arg=True,
                         **kwargs)

    if not result.success:
        result.print_diagnostics()
        raise RydiquleError(f"Integration failed ({result.error_code}): {result.message}")

    # restore the stack/basis axes in front of the solver's trailing time axis
    sol_shape = eoms_base.shape[:-1]
    return result.y.reshape(sol_shape + (result.y.shape[-1],))
def _derEqns(obes_base: np.ndarray, const_base: np.ndarray,
             obes_time_r: np.ndarray, const_r: np.ndarray,
             obes_time_i: np.ndarray, const_i: np.ndarray,
             time_inputs: Sequence[TimeFunc]
             ) -> Callable[[np.ndarray, float, np.ndarray], None]:
    """
    Build the derivative callable handed to CyRK's `pysolve_ivp` cython solver.

    The returned function is njit-compiled with signature `void(f8[::1], f8, f8[::1])`.
    It takes the output buffer, the scalar time and the flat state vector,
    and writes the derivative into the output buffer in place,
    which is the input/output convention expected by `cyrk.pysolve_ivp()`.

    The derivative is assembled from the base and time-dependent matrix components
    of the equations of motion.
    Note that the `time_inputs` functions must be njit compiled.
    """
    import numba as nb

    # shape the flat state is viewed as: all axes of the matrix but the last
    state_shape = obes_base.shape[:-1]
    # leading axes that are looped over one matrix at a time
    stack_axes = obes_base.shape[:-2]
    # number of time-dependent terms
    n_time_funcs = obes_time_r.shape[0]

    @nb.njit("void(f8[::1], f8, f8[::1])")
    def func(result_out: np.ndarray, t: float, A_flat: np.ndarray):

        # assemble the full matrix and constant vector at time t
        matrix_t = obes_base.copy()
        vector_t = const_base.copy()
        for k in range(n_time_funcs):
            amp = time_inputs[k](t)
            matrix_t += amp.real*obes_time_r[k] + amp.imag*obes_time_i[k]
            vector_t += amp.real*const_r[k] + amp.imag*const_i[k]

        # view the flat state with its stack axes restored
        state = A_flat.reshape(state_shape)
        deriv = np.empty_like(state)
        # one matrix-vector product per element of the stack
        for loc in np.ndindex(stack_axes):
            deriv[loc] = np.dot(matrix_t[loc], state[loc])
        # constant term is added with broadcasting, which handles the doppler axis
        deriv += vector_t

        # copy the stacked derivative into the flat output buffer
        for pos, val in enumerate(deriv.flat):
            result_out[pos] = val

    return func
def _derEqns_flat(obes_base: np.ndarray, const_base: np.ndarray,
                  obes_time_r: np.ndarray, const_r: np.ndarray,
                  obes_time_i: np.ndarray, const_i: np.ndarray,
                  time_inputs: Sequence[TimeFunc]
                  ) -> Callable[[np.ndarray, float, np.ndarray], None]:
    """
    Function to build the callable passed to CyRK's pysolve_ivp cython solver.

    Note that `time_inputs` functions must be njit compiled.

    Uses the base and time matrix components of the eoms to build
    a function of vector and scalar time
    that has the expected input/output of functions passed to `cyrk.pysolve_ivp()`

    This implementation is explicitly flat and avoids extra array allocations.
    All equation arrays are flattened up front; the time-dependent parts are first
    broadcast to the full shape of the base arrays so that every array can be
    addressed with the same flat row index.

    Raises
    ------
    RydiquleError
        If the shape of `const_base` does not equal the shape of `obes_base`
        without its last axis (reported as an incompatible doppler solve).
    """
    import numba as nb

    if obes_base.shape[:-1] != const_base.shape:
        raise RydiquleError("CyRK flat solver incompatible with doppler solves")

    # basis dimension size
    b = obes_base.shape[-1]
    # time function dimension size
    t_func_num = obes_time_r.shape[0]
    # shapes to broadcast time-dependent parts to before flattening
    obes_time_shape = (t_func_num, ) + obes_base.shape
    const_time_shape = obes_time_shape[:-1]

    # flatten eqns arrays
    obes_base = obes_base.reshape(-1)
    const_base = const_base.reshape(-1)
    # broadcasts ensure that time-dependent parts trivially match base
    # this ensures we don't have to implement arbitrary broadcasting rules below
    # note that these arrays are not guaranteed to be C-contiguous
    # which may have performance implications
    obes_time_r_broad = np.broadcast_to(obes_time_r, obes_time_shape)
    obes_time_i_broad = np.broadcast_to(obes_time_i, obes_time_shape)
    const_r_broad = np.broadcast_to(const_r, const_time_shape)
    const_i_broad = np.broadcast_to(const_i, const_time_shape)
    # we now flatten to 2D arrays
    obes_time_r = obes_time_r_broad.reshape((t_func_num, -1))
    obes_time_i = obes_time_i_broad.reshape((t_func_num, -1))
    const_r = const_r_broad.reshape((t_func_num, -1))
    const_i = const_i_broad.reshape((t_func_num, -1))

    @nb.njit("void(f8[::1], f8, f8[::1])")
    def func(result_out: np.ndarray, t: float, A_flat: np.ndarray):

        # calculate time inputs at time t
        ts = np.zeros(t_func_num, dtype=np.complex128)
        for idx in range(t_func_num):
            ts[idx] = time_inputs[idx](t)

        for i in range(result_out.size):
            # start result with time-independent constant part
            result_out[i] = const_base[i]
            for idx in range(t_func_num):
                # add time-dependent const part
                # const_r/const_i were broadcast to the full stacked shape before
                # flattening, so they share the flat index i with const_base
                # (i % b would only ever read the first element of the stack)
                result_out[i] += ts[idx].real*const_r[idx, i] + ts[idx].imag*const_i[idx, i]
            for j in range(b):
                # define indices for this step
                # obe_idx: entry (row i, column j) of the flattened matrices
                obe_idx = i*b+j
                # A_idx: entry j of the state vector belonging to the same stack element as row i
                A_idx = (i//b)*b+j
                # add time-independent obe part
                # implements einsum('...ij,...j', obes, A)
                result_out[i] += obes_base[obe_idx] * A_flat[A_idx]
                for idx in range(t_func_num):
                    # add time-dependent obe part
                    result_out[i] += (ts[idx].real*obes_time_r[idx, obe_idx]
                                      + ts[idx].imag*obes_time_i[idx, obe_idx]) * A_flat[A_idx]

    return func
# Lookup table from the `eqns` option of `cyrk_solve` to the function
# that builds the matching derivative callable.
_eqnsGen = {
    "orig": _derEqns,
    "flat": _derEqns_flat,
}