def propagate_trajectory_results(
params: dict, settings: Config, result: OptimizationResults, propagation_solver: callable
) -> OptimizationResults:
"""Propagate the optimal trajectory and compute additional results.
This function takes the optimal control solution and propagates it through the
nonlinear dynamics to compute the actual state trajectory and other metrics.
Args:
params (dict): System parameters.
settings (Config): Configuration settings.
result (OptimizationResults): Optimization results object.
propagation_solver (callable): Function for propagating the system state.
Returns:
OptimizationResults: Updated results object containing:
- t_full: Full time vector
- x_full: Full state trajectory
- u_full: Full control trajectory
- cost: Computed cost
- ctcs_violation: CTCS constraint violation
"""
# Get arrays from result
x = result.x
u = result.u
t = np.array(s_to_t(x, u, settings)).squeeze()
t_full = np.arange(t[0], t[-1], settings.prp.dt)
tau_vals, u_full = t_to_tau(u, t_full, t, settings)
# Create a copy of x_prop for propagation to avoid mutating settings
# Match free values from initial state to the initial value from the result
x_prop_for_propagation = copy.copy(settings.sim.x_prop)
# Only copy for states that exist in optimization (propagation may have extra states at the end)
n_opt_states = x.shape[1]
n_prop_states = settings.sim.x_prop.initial.shape[0]
if n_opt_states == n_prop_states:
# Same size - copy all
# Use metadata from settings (immutable configuration)
mask = jnp.array([t == "Free" for t in settings.sim.x.initial_type], dtype=bool)
x_prop_for_propagation.initial = jnp.where(mask, x[0, :], settings.sim.x_prop.initial)
else:
# Propagation has extra states - only copy the overlapping portion
# Use metadata from settings (immutable configuration)
mask = jnp.array([t == "Free" for t in settings.sim.x.initial_type], dtype=bool)
x_prop_initial_updated = settings.sim.x_prop.initial.copy()
x_prop_initial_updated[:n_opt_states] = jnp.where(
mask, x[0, :], settings.sim.x_prop.initial[:n_opt_states]
)
x_prop_for_propagation.initial = x_prop_initial_updated
# Temporarily replace x_prop with our modified copy for propagation
# Save original to restore after propagation
original_x_prop = settings.sim.x_prop
settings.sim.x_prop = x_prop_for_propagation
try:
x_full = simulate_nonlinear_time(params, x, u, tau_vals, t, settings, propagation_solver)
finally:
# Always restore original x_prop, even if propagation fails
settings.sim.x_prop = original_x_prop
# Calculate cost using utility function and metadata from settings
cost = calculate_cost_from_boundaries(x, settings.sim.x.initial_type, settings.sim.x.final_type)
# Calculate CTCS constraint violation
ctcs_violation = x_full[-1, settings.sim.ctcs_slice_prop]
# Build trajectory dictionary with all states and controls
trajectory_dict = {}
# Add all states (user-defined and augmented)
for state in result._states:
trajectory_dict[state.name] = x_full[:, state._slice]
# Add all controls (user-defined and augmented)
for control in result._controls:
trajectory_dict[control.name] = u_full[:, control._slice]
# Update the results object with post-processing data
result.t_full = t_full
result.x_full = x_full
result.u_full = u_full
result.cost = cost
result.ctcs_violation = ctcs_violation
result.trajectory = trajectory_dict
return result