Exact Discretization¶
Warning
This page is still under development .
dVdt.py
def dVdt(self,
tau: float,
V: jnp.ndarray,
u_cur: np.ndarray,
u_next: np.ndarray
) -> jnp.ndarray:
"""
Computes the time derivative of the augmented state vector for the system for a sequence of states.
Parameters:
tau (float): Current time.
V (np.ndarray): Sequence of augmented state vectors.
u_cur (np.ndarray): Sequence of current control inputs.
u_next (np.ndarray): Sequence of next control inputs.
A: Function that computes the Jacobian of the system dynamics with respect to the state.
B: Function that computes the Jacobian of the system dynamics with respect to the control input.
obstacles: List of obstacles in the environment.
params (dict): Parameters of the system.
Returns:
np.ndarray: Time derivatives of the augmented state vectors.
"""
# Extract the number of states and controls from the parameters
n_x = self.params.sim.n_states
n_u = self.params.sim.n_controls
# Unflatten V
V = V.reshape(-1, self.i5)
# Compute the interpolation factor based on the discretization type
if self.params.dis.dis_type == 'ZOH':
beta = 0.
elif self.params.dis.dis_type == 'FOH':
beta = (tau) * self.params.scp.n
alpha = 1 - beta
# Interpolate the control input
u = u_cur + beta * (u_next - u_cur)
s = u[:,-1]
# Initialize the augmented Jacobians
dfdx = jnp.zeros((V.shape[0], n_x, n_x))
dfdu = jnp.zeros((V.shape[0], n_x, n_u))
# Ensure x_seq and u have the same batch size
x = V[:,:self.params.sim.n_states]
u = u[:x.shape[0]]
# Compute the nonlinear propagation term
f = self.params.dyn.state_dot(x, u[:,:-1])
F = s[:, None] * f
# Evaluate the State Jacobian
dfdx = self.params.dyn.A(x, u[:,:-1])
sdfdx = s[:, None, None] * dfdx
# Evaluate the Control Jacobian
dfdu_veh = self.params.dyn.B(x, u[:,:-1])
dfdu = dfdu.at[:, :, :-1].set(s[:, None, None] * dfdu_veh)
dfdu = dfdu.at[:, :, -1].set(f)
# Compute the defect
z = F - jnp.einsum('ijk,ik->ij', sdfdx, x) - jnp.einsum('ijk,ik->ij', dfdu, u)
# Stack up the results into the augmented state vector
dVdt = jnp.zeros_like(V)
dVdt = dVdt.at[:, self.i0:self.i1].set(F)
dVdt = dVdt.at[:, self.i1:self.i2].set(jnp.matmul(sdfdx, V[:, self.i1:self.i2].reshape(-1, n_x, n_x)).reshape(-1, n_x * n_x))
dVdt = dVdt.at[:, self.i2:self.i3].set((jnp.matmul(sdfdx, V[:, self.i2:self.i3].reshape(-1, n_x, n_u)) + dfdu * alpha).reshape(-1, n_x * n_u))
dVdt = dVdt.at[:, self.i3:self.i4].set((jnp.matmul(sdfdx, V[:, self.i3:self.i4].reshape(-1, n_x, n_u)) + dfdu * beta).reshape(-1, n_x * n_u))
dVdt = dVdt.at[:, self.i4:self.i5].set((jnp.matmul(sdfdx, V[:, self.i4:self.i5].reshape(-1, n_x)[..., None]).squeeze(-1) + z).reshape(-1, n_x))
return dVdt.flatten()