integrators
Numerical integration schemes for trajectory optimization.
This module provides implementations of numerical integrators used for simulating continuous-time dynamics.
Current Implementations
RK45 Integration: Explicit Runge-Kutta-Fehlberg 4(5) method with a fixed-step implementation (`solve_ivp_rk45`). Adaptive integration is delegated to Diffrax (`solve_ivp_diffrax`, `solve_ivp_diffrax_prop`), which supports a variety of explicit and implicit ODE solvers (Dopri5/8, Tsit5, KenCarp3/4/5, etc.).
Planned Architecture (ABC-based):
A base class will be introduced to enable pluggable integrator implementations, so that users can implement custom integrators. Future integrators will implement the Integrator interface:
```python
# integrators/base.py (planned):
from abc import ABC, abstractmethod
from typing import Callable

from jax import Array


class Integrator(ABC):
    @abstractmethod
    def step(self, f: Callable, x: Array, u: Array, t: float, dt: float) -> Array:
        """Take one integration step from state x at time t with step dt."""
        ...

    @abstractmethod
    def integrate(self, f: Callable, x0: Array, u_traj: Array,
                  t_span: tuple[float, float], num_steps: int) -> Array:
        """Integrate over a time span with given control trajectory."""
        ...
```
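To make the planned interface concrete, here is a minimal sketch of a custom integrator. It is not part of openscvx: the class name `RK4Integrator` is hypothetical, the dynamics are assumed to have the signature `f(x, u, t) -> dx/dt` (the planned interface does not pin this down), `u_traj` is assumed to hold one control row per step under zero-order hold, and NumPy stands in for JAX arrays so the example runs on its own. The base class is restated so the block is self-contained.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable

import numpy as np


class Integrator(ABC):
    """Restatement of the planned interface so this sketch runs on its own."""

    @abstractmethod
    def step(self, f: Callable, x: np.ndarray, u: np.ndarray,
             t: float, dt: float) -> np.ndarray: ...

    @abstractmethod
    def integrate(self, f: Callable, x0: np.ndarray, u_traj: np.ndarray,
                  t_span: tuple[float, float], num_steps: int) -> np.ndarray: ...


class RK4Integrator(Integrator):
    """Hypothetical classical RK4 integrator; assumes f(x, u, t) -> dx/dt."""

    def step(self, f, x, u, t, dt):
        # The control u is held constant across the step (zero-order hold).
        k1 = f(x, u, t)
        k2 = f(x + 0.5 * dt * k1, u, t + 0.5 * dt)
        k3 = f(x + 0.5 * dt * k2, u, t + 0.5 * dt)
        k4 = f(x + dt * k3, u, t + dt)
        return x + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    def integrate(self, f, x0, u_traj, t_span, num_steps):
        t0, tf = t_span
        dt = (tf - t0) / num_steps
        xs = [np.asarray(x0, dtype=float)]
        for i in range(num_steps):
            # Assumes u_traj holds one control row per step.
            xs.append(self.step(f, xs[-1], u_traj[i], t0 + i * dt, dt))
        return np.stack(xs)  # shape (num_steps + 1, state_dim)


# Usage: double integrator (position, velocity) under unit acceleration.
def double_integrator(x, u, t):
    return np.array([x[1], u[0]])


xs = RK4Integrator().integrate(double_integrator, np.zeros(2), np.ones((10, 1)),
                               (0.0, 1.0), 10)
# After 1 s: position 0.5, velocity 1.0 (RK4 is exact here up to round-off).
```

Returning `num_steps + 1` states, including `x0`, is a choice made for this sketch; the planned interface does not specify the output length.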
rk45_step(f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], t: jnp.ndarray, y: jnp.ndarray, h: float, *args) -> jnp.ndarray
Perform a single RK45 (Runge-Kutta-Fehlberg) integration step.
This implements the classic Fehlberg coefficients for an explicit 4(5) method and returns the fourth-order estimate.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[[ndarray, ndarray, Any], ndarray]` | ODE right-hand side; signature `f(t, y, *args) -> dy/dt`. | required |
| `t` | `ndarray` | Current time. | required |
| `y` | `ndarray` | Current state vector. | required |
| `h` | `float` | Step size. | required |
| `*args` | | Additional arguments passed to `f`. | `()` |
Returns:

| Type | Description |
|---|---|
| `jnp.ndarray` | Next state estimate at `t + h`. |
Source code in openscvx/integrators/runge_kutta.py
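A minimal sketch of what a single Runge-Kutta-Fehlberg step computes, using the standard published Fehlberg tableau. It follows the documented `f(t, y, *args)` convention but is not the openscvx source; `rkf45_step` is a hypothetical name, and NumPy arrays are used so it runs without JAX (the same arithmetic works on `jax.numpy` arrays).

```python
import numpy as np


def rkf45_step(f, t, y, h, *args):
    """One Runge-Kutta-Fehlberg step; returns the fourth-order estimate.

    Sketch using the standard Fehlberg tableau, not the openscvx source.
    f follows the documented convention f(t, y, *args) -> dy/dt.
    """
    k1 = f(t, y, *args)
    k2 = f(t + h / 4, y + h * k1 / 4, *args)
    k3 = f(t + 3 * h / 8, y + h * (3 * k1 + 9 * k2) / 32, *args)
    k4 = f(t + 12 * h / 13,
           y + h * (1932 * k1 - 7200 * k2 + 7296 * k3) / 2197, *args)
    k5 = f(t + h,
           y + h * (439 / 216 * k1 - 8 * k2 + 3680 / 513 * k3
                    - 845 / 4104 * k4), *args)
    # Fourth-order weights. The sixth Fehlberg stage only feeds the
    # fifth-order solution, so it is not evaluated here.
    return y + h * (25 / 216 * k1 + 1408 / 2565 * k3
                    + 2197 / 4104 * k4 - k5 / 5)


# Usage: exponential decay dy/dt = -k * y, with k passed through *args.
y1 = rkf45_step(lambda t, y, k: -k * y, 0.0, np.array([1.0]), 0.1, 2.0)
# y1[0] is close to exp(-0.2); the local error is O(h**5).

# Cubic check: dy/dt = t**3 integrates to exactly t**4 / 4.
cubic = rkf45_step(lambda t, y: np.array([t**3]), 0.0, np.array([0.0]), 1.0)
```

The cubic check works because the fourth-order weights integrate polynomials up to degree three exactly, so one step of size 1 from 0 returns 1/4 up to round-off.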
solve_ivp_diffrax(f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], tau_final: float, y_0: jnp.ndarray, args, tau_0: float = 0.0, num_substeps: int = 50, solver_name: str = 'Dopri8', rtol: float = 0.001, atol: float = 1e-06, extra_kwargs=None)
Solve an initial-value ODE problem using a Diffrax adaptive solver.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[[ndarray, ndarray, Any], ndarray]` | ODE right-hand side; signature `f(t, y, *args) -> dy/dt`. | required |
| `tau_final` | `float` | Final integration time. | required |
| `y_0` | `ndarray` | Initial state at `tau_0`. | required |
| `args` | `tuple` | Extra arguments to pass to `f`. | required |
| `tau_0` | `float` | Initial time. Defaults to 0.0. | `0.0` |
| `num_substeps` | `int` | Number of save points between `tau_0` and `tau_final`. Defaults to 50. | `50` |
| `solver_name` | `str` | Key into `SOLVER_MAP` for the Diffrax solver class. Defaults to `"Dopri8"`. | `'Dopri8'` |
| `rtol` | `float` | Relative tolerance for adaptive stepping. Defaults to 1e-3. | `0.001` |
| `atol` | `float` | Absolute tolerance for adaptive stepping. Defaults to 1e-6. | `1e-06` |
| `extra_kwargs` | `dict` | Additional keyword arguments forwarded to | `None` |
Returns:

| Type | Description |
|---|---|
| `jnp.ndarray` | Solution states at the requested save points, shape `(num_substeps, state_dim)`. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `solver_name` is not in `SOLVER_MAP`. |
Source code in openscvx/integrators/runge_kutta.py
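The sketch below illustrates what `rtol`, `atol`, and `num_substeps` control in an adaptive solve. It does not use Diffrax: it is a hand-rolled NumPy stand-in built on the Fehlberg 4(5) embedded pair that advances with the fifth-order solution and uses a common error-norm and step-size rule. Diffrax's solvers (such as Dopri8) and step-size controllers differ in detail. The names `solve_adaptive`, `embedded_step`, and `max_steps` are hypothetical, and the save points are assumed to be evenly spaced with both endpoints included.

```python
import numpy as np

# Standard Fehlberg 4(5) tableau: nodes C, stage coefficients A,
# fourth-order weights B4 and fifth-order weights B5.
C = np.array([0.0, 1 / 4, 3 / 8, 12 / 13, 1.0, 1 / 2])
A = [
    [],
    [1 / 4],
    [3 / 32, 9 / 32],
    [1932 / 2197, -7200 / 2197, 7296 / 2197],
    [439 / 216, -8.0, 3680 / 513, -845 / 4104],
    [-8 / 27, 2.0, -3544 / 2565, 1859 / 4104, -11 / 40],
]
B4 = np.array([25 / 216, 0.0, 1408 / 2565, 2197 / 4104, -1 / 5, 0.0])
B5 = np.array([16 / 135, 0.0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55])


def embedded_step(f, t, y, h, args):
    """Return (fifth-order solution, local error estimate) for one step."""
    ks = []
    for i in range(6):
        y_stage = y + h * sum(a * k for a, k in zip(A[i], ks))
        ks.append(f(t + C[i] * h, y_stage, *args))
    ks = np.stack(ks)
    return y + h * (B5 @ ks), h * ((B5 - B4) @ ks)


def solve_adaptive(f, tau_final, y_0, args, tau_0=0.0, num_substeps=50,
                   rtol=1e-3, atol=1e-6, max_steps=10_000):
    """Adaptive solve reporting y at num_substeps evenly spaced times."""
    save_ts = np.linspace(tau_0, tau_final, num_substeps)
    t, y = tau_0, np.asarray(y_0, dtype=float)
    h = (tau_final - tau_0) / num_substeps  # initial step-size guess
    out, steps = [y], 0
    for t_save in save_ts[1:]:
        while t < t_save:
            if steps >= max_steps:
                raise RuntimeError("max_steps exceeded")
            remaining = t_save - t
            h_try = min(h, remaining)  # never step past a save point
            y_new, err = embedded_step(f, t, y, h_try, args)
            # Mixed absolute/relative scale, then RMS norm over components.
            scale = atol + rtol * np.maximum(np.abs(y), np.abs(y_new))
            err_norm = float(np.sqrt(np.mean((err / scale) ** 2)))
            if err_norm <= 1.0:
                # Accept; snap exactly onto the save point on the last sub-step.
                t = t_save if h_try == remaining else t + h_try
                y = y_new
            # Grow or shrink h by (1/err)^(1/5), with a safety factor and clamps.
            factor = 5.0 if err_norm == 0.0 else 0.9 * err_norm ** -0.2
            h = h_try * min(5.0, max(0.2, factor))
            steps += 1
        out.append(y)
    return np.stack(out)  # shape (num_substeps, state_dim)


# Usage: dy/dt = -k * y with k = 1, saved at 11 points on [0, 1].
ys = solve_adaptive(lambda t, y, k: -k * y, 1.0, np.array([1.0]), (1.0,),
                    num_substeps=11, rtol=1e-8, atol=1e-10)
```

A step is accepted when the scaled RMS error is at most 1, where each component is scaled by `atol + rtol * |y|`; tightening either tolerance forces smaller steps. Stepping exactly onto each save point is a simple alternative to interpolating between steps.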
solve_ivp_diffrax_prop(f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], tau_final: float, y_0: jnp.ndarray, args, tau_0: float = 0.0, num_substeps: int = 50, solver_name: str = 'Dopri8', rtol: float = 0.001, atol: float = 1e-06, extra_kwargs=None, save_time: jnp.ndarray = None, mask: jnp.ndarray = None)
Solve an initial-value ODE problem using a Diffrax adaptive solver. This function is designed for trajectory optimization, where it performs the nonlinear single-shot propagation of the state variables in undilated time.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[[ndarray, ndarray, Any], ndarray]` | ODE right-hand side; signature `f(t, y, *args) -> dy/dt`. | required |
| `tau_final` | `float` | Final integration time. | required |
| `y_0` | `ndarray` | Initial state at `tau_0`. | required |
| `args` | `tuple` | Extra arguments to pass to `f`. | required |
| `tau_0` | `float` | Initial time. Defaults to 0.0. | `0.0` |
| `num_substeps` | `int` | Number of save points between `tau_0` and `tau_final`. Defaults to 50. | `50` |
| `solver_name` | `str` | Key into `SOLVER_MAP` for the Diffrax solver class. Defaults to `"Dopri8"`. | `'Dopri8'` |
| `rtol` | `float` | Relative tolerance for adaptive stepping. Defaults to 1e-3. | `0.001` |
| `atol` | `float` | Absolute tolerance for adaptive stepping. Defaults to 1e-6. | `1e-06` |
| `extra_kwargs` | `dict` | Additional keyword arguments forwarded to | `None` |
| `save_time` | `ndarray` | Time points at which to evaluate the solution. Must be provided for export compatibility. | `None` |
| `mask` | `ndarray` | Boolean mask for the `save_time` points. | `None` |
Returns:

| Type | Description |
|---|---|
| `jnp.ndarray` | Solution states at the requested save points, shape `(num_substeps, state_dim)`. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `solver_name` is not in `SOLVER_MAP` or if `save_time` is not provided. |
Source code in openscvx/integrators/runge_kutta.py
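The `save_time` parameter lets the caller choose arbitrary output times instead of a uniform grid. A minimal sketch of that idea, under stated assumptions: the trajectory is propagated once (single shot) from `y_0`, never reset between intervals, and the state is recorded at each entry of a sorted `save_time`. A classical RK4 step with a fixed number of sub-steps per interval stands in for the Diffrax solver, and `propagate_at`, `rk4_step`, and `substeps_per_interval` are hypothetical names. How `mask` is applied is not modeled, since this page only describes it as a boolean mask over the `save_time` points.

```python
import numpy as np


def rk4_step(f, t, y, h, *args):
    """Classical RK4 step (a stand-in for the Diffrax solver)."""
    k1 = f(t, y, *args)
    k2 = f(t + h / 2, y + h / 2 * k1, *args)
    k3 = f(t + h / 2, y + h / 2 * k2, *args)
    k4 = f(t + h, y + h * k3, *args)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def propagate_at(f, y_0, args, save_time, tau_0=0.0, substeps_per_interval=20):
    """Propagate once from y_0 and record the state at each save_time entry.

    Assumes save_time is sorted and starts at or after tau_0.
    """
    t, y = tau_0, np.asarray(y_0, dtype=float)
    out = []
    for t_next in np.asarray(save_time, dtype=float):
        h = (t_next - t) / substeps_per_interval
        for _ in range(substeps_per_interval):
            y = rk4_step(f, t, y, h, *args)
            t += h
        t = t_next  # discard accumulated round-off in t
        out.append(y)
    return np.stack(out)  # shape (len(save_time), state_dim)


# Usage: harmonic oscillator x'' = -w**2 x with w = 2, so x(t) = cos(2 t).
save_time = np.array([0.0, 0.3, 1.0, 1.7])
ys = propagate_at(lambda t, y, w: np.array([y[1], -w**2 * y[0]]),
                  np.array([1.0, 0.0]), (2.0,), save_time)
```

This sketch returns one row per `save_time` entry; that layout is a choice of the sketch, not a statement about the real function's output.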
solve_ivp_rk45(f: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], tau_final: float, y_0: jnp.ndarray, args, tau_0: float = 0.0, num_substeps: int = 50, is_not_compiled: bool = False)
Solve an initial-value ODE problem using fixed-step RK45 integration.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[[ndarray, ndarray, Any], ndarray]` | ODE right-hand side; signature `f(t, y, *args) -> dy/dt`. | required |
| `tau_final` | `float` | Final integration time. | required |
| `y_0` | `ndarray` | Initial state at `tau_0`. | required |
| `args` | `tuple` | Extra arguments to pass to `f`. | required |
| `tau_0` | `float` | Initial time. Defaults to 0.0. | `0.0` |
| `num_substeps` | `int` | Number of output time points. Defaults to 50. | `50` |
| `is_not_compiled` | `bool` | If True, use a Python loop instead of JAX | `False` |
Returns:

| Type | Description |
|---|---|
| `jnp.ndarray` | Array of shape `(num_substeps, state_dim)` with the solution at each time point. |
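A minimal sketch of fixed-step integration on a uniform output grid, assuming the `num_substeps` output points include both `tau_0` and `tau_final`, with `y_0` as the first row. It mirrors the choice `is_not_compiled` controls: an explicit Python loop versus a carry-based fold (`itertools.accumulate` here), which has the same structure as a JAX loop primitive such as `jax.lax.scan` (which primitive openscvx uses is not stated on this page). NumPy replaces JAX, and `solve_fixed_step`, `rkf45_step4`, and `use_loop` are hypothetical names.

```python
from itertools import accumulate

import numpy as np


def rkf45_step4(f, t, y, h, *args):
    """Fehlberg 4(5) stages, returning the fourth-order estimate."""
    k1 = f(t, y, *args)
    k2 = f(t + h / 4, y + h * k1 / 4, *args)
    k3 = f(t + 3 * h / 8, y + h * (3 * k1 + 9 * k2) / 32, *args)
    k4 = f(t + 12 * h / 13,
           y + h * (1932 * k1 - 7200 * k2 + 7296 * k3) / 2197, *args)
    k5 = f(t + h, y + h * (439 / 216 * k1 - 8 * k2 + 3680 / 513 * k3
                           - 845 / 4104 * k4), *args)
    return y + h * (25 / 216 * k1 + 1408 / 2565 * k3 + 2197 / 4104 * k4 - k5 / 5)


def solve_fixed_step(f, tau_final, y_0, args, tau_0=0.0, num_substeps=50,
                     use_loop=False):
    """Fixed-step solve on a uniform grid; assumes num_substeps >= 2."""
    ts = np.linspace(tau_0, tau_final, num_substeps)
    h = ts[1] - ts[0]
    y0 = np.asarray(y_0, dtype=float)
    if use_loop:
        # Plain Python loop: easy to debug, not traceable as a single loop.
        ys = [y0]
        for t in ts[:-1]:
            ys.append(rkf45_step4(f, t, ys[-1], h, *args))
    else:
        # Carry-based fold: the state is the carry, one step per grid time.
        ys = list(accumulate(ts[:-1],
                             lambda y, t: rkf45_step4(f, t, y, h, *args),
                             initial=y0))
    return np.stack(ys)  # shape (num_substeps, state_dim)


# Usage: exponential decay dy/dt = -k * y with k = 1 on [0, 1].
decay = lambda t, y, k: -k * y
ys = solve_fixed_step(decay, 1.0, [1.0], (1.0,), num_substeps=50)
ys_loop = solve_fixed_step(decay, 1.0, [1.0], (1.0,), num_substeps=50,
                           use_loop=True)
```

Both paths apply the same step in the same order, so they produce the same result; the fold form is the one whose structure maps onto a compiled loop.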