Obstacle Avoidance Realtime
Interactive real-time visualization for drone obstacle avoidance.
This module provides a PyQt5-based GUI for interactively solving and visualizing the drone obstacle avoidance trajectory optimization problem in real time.
File: examples/realtime/obstacle_avoidance_realtime.py
import math
import os
import queue
import sys
import threading
import time
from collections.abc import Mapping

import numpy as np
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtWidgets import (
QApplication,
QGroupBox,
QHBoxLayout,
QLabel,
QLineEdit,
QPushButton,
QSlider,
QVBoxLayout,
QWidget,
)
# Put the repository root on sys.path so the ``examples.drone`` import below
# resolves when this script is launched directly instead of as a package module.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Two levels up: examples/realtime -> examples -> repository root.
grandparent_dir = os.path.split(os.path.split(current_dir)[0])[0]
sys.path.append(grandparent_dir)
from examples.drone.obstacle_avoidance_realtime_base import (
obstacle_centers,
plotting_dict,
problem,
)
# Optional 3D backend. HAS_OPENGL records whether the PyQtGraph OpenGL items
# could be imported; the rest of the module checks it before creating or
# touching any GL item.
try:
    from pyqtgraph.opengl import (
        GLGridItem,
        GLMeshItem,
        GLScatterPlotItem,
        GLViewWidget,
        MeshData,
    )
except ImportError:
    HAS_OPENGL = False
    print("PyQtGraph OpenGL not available, falling back to 2D")
else:
    HAS_OPENGL = True
# State shared between the solver worker thread and the Qt GUI. Single-key
# dicts act as mutable boxes, so either side can update a value in place via
# subscript assignment instead of rebinding a module global.
running = {"stop": False}  # True asks optimization_loop's while-loop to end
reset_requested = {"reset": False}  # GUI sets True; the solver resets and clears it
latest_results = {"results": None}  # newest results dict published by the solver
new_result_event = threading.Event()  # set by the solver after each publish
class Obstacle3DPlotWidget(QWidget):
    """Main GUI window: a 3D trajectory/obstacle view plus a control panel.

    The control panel shows solver metrics, edits the SCP weights (λ_cost,
    λ_tr), requests a problem reset, and moves each obstacle centre with X/Y/Z
    sliders. Slider moves write into ``problem.parameters`` and the matching
    entry of ``obstacle_centers``. Without PyQtGraph's OpenGL module the 3D
    view is replaced by a placeholder label, but the control panel is still
    shown so obstacles and weights remain editable.
    """

    # Integer slider range and the world-space distance of one tick:
    # [-100, 100] ticks * 0.05 -> [-5.0, 5.0] world units.
    SLIDER_MIN = -100
    SLIDER_MAX = 100
    SLIDER_STEP = 0.05

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ellipsoid shape data comes from the problem definition's plotting_dict.
        self.ellipsoid_axes = plotting_dict["obstacles_axes"]
        self.ellipsoid_radii = plotting_dict["obstacles_radii"]
        self.obs_ellipsoids = []

        layout = QVBoxLayout()
        self.setLayout(layout)
        main_layout = QHBoxLayout()

        if HAS_OPENGL:
            self.view = GLViewWidget()
            self.view.setCameraPosition(distance=15)
            self.view.addItem(GLGridItem())
            # Single placeholder point; the trajectory replaces it on each update.
            self.traj_scatter = GLScatterPlotItem(
                pos=np.zeros((1, 3)), color=(0, 0, 1, 1), size=5
            )
            self.view.addItem(self.traj_scatter)
            self.create_obstacle_ellipsoids()
            main_layout.addWidget(self.view, stretch=3)
        else:
            main_layout.addWidget(QLabel("3D OpenGL not available"), stretch=3)

        self.create_control_panel()
        main_layout.addWidget(self.control_panel, stretch=1)
        layout.addLayout(main_layout)

        # Sliders are built at 0; sync them (and the ellipsoids) to the
        # obstacles' real positions so the panel does not display stale zeros.
        for obstacle_idx in range(len(obstacle_centers)):
            self.update_slider_values(obstacle_idx)
            self.update_obstacle_position(obstacle_idx)

    def _add_weight_input(self, parent_layout, caption, value, handler):
        """Append an ``[input] caption`` row; ``handler(input)`` runs on Enter.

        Returns the created ``QLineEdit``.
        """
        row = QHBoxLayout()
        line_edit = QLineEdit()
        line_edit.setText(f"{value:.2E}")
        line_edit.setFixedWidth(80)
        line_edit.returnPressed.connect(lambda: handler(line_edit))
        caption_label = QLabel(caption)
        caption_label.setAlignment(Qt.AlignLeft)
        row.addWidget(line_edit)
        row.addWidget(caption_label)
        row.addStretch()  # push the row's widgets to the left
        parent_layout.addLayout(row)
        return line_edit

    def create_control_panel(self):
        """Build ``self.control_panel`` and ``self.labels_dict``.

        Sections, top to bottom: title, optimization metrics, optimization
        weights (λ_cost / λ_tr, applied when Enter is pressed), a reset button,
        and one X/Y/Z slider group per obstacle. Each group's sliders are stored
        as ``self.obs_<i>_sliders``: a list of ``(slider, value_label)`` pairs.
        """
        self.control_panel = QWidget()
        control_layout = QVBoxLayout()
        self.control_panel.setLayout(control_layout)

        title = QLabel("3D Obstacle Avoidance Control")
        title.setStyleSheet("font-weight: bold; font-size: 14px;")
        control_layout.addWidget(title)

        # Optimization metrics; text is refreshed by update_optimization_metrics.
        metrics_group = QGroupBox("Optimization Metrics")
        metrics_layout = QVBoxLayout()
        metrics_group.setLayout(metrics_layout)
        self.iter_label = QLabel("Iteration: 0")
        self.j_tr_label = QLabel("J_tr: 0.00e+00")
        self.j_vb_label = QLabel("J_vb: 0.00e+00")
        self.j_vc_label = QLabel("J_vc: 0.00e+00")
        self.objective_label = QLabel("Objective: 0.00e+00")
        self.lam_cost_display_label = QLabel(f"λ_cost: {problem.settings.scp.lam_cost:.2E}")
        self.dis_time_label = QLabel("Dis Time: 0.0ms")
        self.solve_time_label = QLabel("Solve Time: 0.0ms")
        self.status_label = QLabel("Status: --")
        # Name -> label mapping consumed by update_optimization_metrics; its
        # insertion order is also the on-screen order.
        self.labels_dict = {
            "iter_label": self.iter_label,
            "j_tr_label": self.j_tr_label,
            "j_vb_label": self.j_vb_label,
            "j_vc_label": self.j_vc_label,
            "objective_label": self.objective_label,
            "lam_cost_display_label": self.lam_cost_display_label,
            "dis_time_label": self.dis_time_label,
            "solve_time_label": self.solve_time_label,
            "status_label": self.status_label,
        }
        for metric_label in self.labels_dict.values():
            metric_label.setStyleSheet("font-family: monospace; font-size: 11px; padding: 2px;")
            metrics_layout.addWidget(metric_label)
        control_layout.addWidget(metrics_group)

        # Optimization weights.
        weights_group = QGroupBox("Optimization Weights")
        weights_layout = QVBoxLayout()
        weights_group.setLayout(weights_layout)
        self._add_weight_input(
            weights_layout, "λ_cost:", problem.settings.scp.lam_cost, on_lam_cost_changed
        )
        self._add_weight_input(
            weights_layout, "λ_tr:", problem.settings.scp.w_tr, on_lam_tr_changed
        )
        control_layout.addWidget(weights_group)

        # Problem control.
        problem_control_group = QGroupBox("Problem Control")
        problem_control_layout = QVBoxLayout()
        problem_control_group.setLayout(problem_control_layout)
        reset_problem_button = QPushButton("Reset Problem")
        reset_problem_button.clicked.connect(self.on_reset_clicked)
        problem_control_layout.addWidget(reset_problem_button)
        control_layout.addWidget(problem_control_group)

        # One slider group per obstacle.
        for obs_idx in range(len(obstacle_centers)):
            obs_group = QGroupBox(f"Obstacle {obs_idx + 1} Position")
            obs_layout = QVBoxLayout()
            obs_group.setLayout(obs_layout)
            sliders = []
            for axis, coord in enumerate(["X", "Y", "Z"]):
                slider_layout = QHBoxLayout()
                slider = QSlider(Qt.Horizontal)
                slider.setRange(self.SLIDER_MIN, self.SLIDER_MAX)
                slider.setValue(0)
                value_label = QLabel("0.00")
                # Loop variables are bound as defaults so each slider keeps its
                # own obstacle/axis/label instead of the last loop values.
                slider.valueChanged.connect(
                    lambda val, obs=obs_idx, ax=axis, lbl=value_label: self.on_slider_changed(
                        val, obs, ax, lbl
                    )
                )
                slider_layout.addWidget(QLabel(f"{coord}:"))
                slider_layout.addWidget(slider)
                slider_layout.addWidget(value_label)
                obs_layout.addLayout(slider_layout)
                sliders.append((slider, value_label))
            setattr(self, f"obs_{obs_idx}_sliders", sliders)
            control_layout.addWidget(obs_group)

        control_layout.addStretch()

    def create_obstacle_ellipsoids(self):
        """Add one translucent green ellipsoid mesh per obstacle to the 3D view.

        Each mesh starts as a unit sphere whose vertices are divided by the
        obstacle's radii and then multiplied by ``axes.T``. Meshes are created
        with an identity transform; ``update_obstacle_position`` moves them.
        """
        if not HAS_OPENGL:
            return
        for axes, radii in zip(self.ellipsoid_axes, self.ellipsoid_radii):
            mesh = MeshData.sphere(rows=20, cols=20, radius=1.0)
            # Scale, then rotate the unit sphere into the obstacle's shape.
            mesh.setVertexes((mesh.vertexes() / radii) @ axes.T)
            ellipsoid = GLMeshItem(
                meshdata=mesh,
                color=(0, 1, 0, 0.3),  # RGBA: green, mostly transparent
                shader="shaded",
                smooth=True,
            )
            ellipsoid.setGLOptions("translucent")  # enable alpha blending
            self.obs_ellipsoids.append(ellipsoid)
            self.view.addItem(ellipsoid)

    def on_slider_changed(self, value, obstacle_idx, axis, label):
        """Move one coordinate of one obstacle in response to a slider change.

        Args:
            value: Integer slider position in [SLIDER_MIN, SLIDER_MAX].
            obstacle_idx: 0-based obstacle index.
            axis: 0, 1 or 2 for X, Y or Z.
            label: QLabel that displays the coordinate's world value.
        """
        world_value = value * self.SLIDER_STEP
        param_name = f"obstacle_center_{obstacle_idx + 1}"
        # Copy so the stored array is replaced rather than mutated in place.
        center = problem.parameters[param_name].copy()
        center[axis] = world_value
        # Keep the Parameter object and problem.parameters in sync.
        obstacle_centers[obstacle_idx].value = center
        problem.parameters[param_name] = center
        self.update_obstacle_position(obstacle_idx)
        label.setText(f"{world_value:.2f}")

    def update_obstacle_position(self, obstacle_idx):
        """Move obstacle ``obstacle_idx``'s ellipsoid to its current centre.

        No-op without OpenGL or when no ellipsoid exists for that index.
        """
        if not HAS_OPENGL or obstacle_idx >= len(self.obs_ellipsoids):
            return
        center = problem.parameters[f"obstacle_center_{obstacle_idx + 1}"]
        ellipsoid = self.obs_ellipsoids[obstacle_idx]
        # translate() is relative to the current transform, so start from identity.
        ellipsoid.resetTransform()
        ellipsoid.translate(center[0], center[1], center[2])

    def update_slider_values(self, obstacle_idx):
        """Set one obstacle's sliders and labels from its current centre.

        Does not modify the problem: slider signals are blocked while syncing,
        so the quantized slider position is never written back as the
        obstacle's centre. Positions outside the slider range are clamped on
        the slider, while the label still shows the exact coordinate.
        """
        center = problem.parameters[f"obstacle_center_{obstacle_idx + 1}"]
        sliders = getattr(self, f"obs_{obstacle_idx}_sliders")
        for axis, (slider, value_label) in enumerate(sliders):
            coord = float(center[axis])
            # round(), not int(): int() truncates, and e.g. 0.15 / 0.05 ==
            # 2.9999999999999996 would land one tick short (0.10).
            ticks = int(round(coord / self.SLIDER_STEP))
            ticks = max(self.SLIDER_MIN, min(self.SLIDER_MAX, ticks))
            slider.blockSignals(True)
            try:
                slider.setValue(ticks)
            finally:
                slider.blockSignals(False)
            value_label.setText(f"{coord:.2f}")

    def on_reset_clicked(self):
        """Ask the optimization thread to reset the problem on its next pass."""
        reset_requested["reset"] = True
        print("Problem reset requested")

    def keyPressEvent(self, event):
        """Close the window on Escape; defer every other key to Qt."""
        if event.key() == Qt.Key_Escape:
            self.close()
        else:
            super().keyPressEvent(event)
def on_lam_cost_changed(input_widget):
    """Apply the λ_cost value typed into ``input_widget`` to the SCP settings.

    Connected to the line edit's ``returnPressed`` signal. The text must parse
    as a finite, non-negative float. On success ``problem.settings.scp.lam_cost``
    is updated and the field is reformatted in scientific notation. Otherwise a
    message is printed and the field is reset to the value currently in use, so
    the GUI never shows a weight the solver is not using.

    Args:
        input_widget: QLineEdit holding the user's text.
    """
    current_value = problem.settings.scp.lam_cost
    try:
        lam_cost_value = float(input_widget.text())
    except ValueError:
        print("Invalid input. Please enter a valid number.")
        input_widget.setText(f"{current_value:.2E}")
        return
    # nan/inf are never usable weights; negatives are rejected on the
    # assumption that lam_cost must be >= 0 (TODO confirm against the solver).
    if not math.isfinite(lam_cost_value) or lam_cost_value < 0:
        print("Invalid input. λ_cost must be a finite, non-negative number.")
        input_widget.setText(f"{current_value:.2E}")
        return
    problem.settings.scp.lam_cost = lam_cost_value
    input_widget.setText(f"{lam_cost_value:.2E}")
def on_lam_tr_changed(input_widget):
    """Apply the λ_tr (trust-region weight) typed into ``input_widget``.

    Connected to the line edit's ``returnPressed`` signal. The text must parse
    as a finite, non-negative float. On success ``problem.settings.scp.w_tr``
    is updated and the field is reformatted in scientific notation. Otherwise a
    message is printed and the field is reset to the value currently in use, so
    the GUI never shows a weight the solver is not using.

    Args:
        input_widget: QLineEdit holding the user's text.
    """
    current_value = problem.settings.scp.w_tr
    try:
        lam_tr_value = float(input_widget.text())
    except ValueError:
        print("Invalid input. Please enter a valid number.")
        input_widget.setText(f"{current_value:.2E}")
        return
    # nan/inf are never usable weights; negatives are rejected on the
    # assumption that w_tr must be >= 0 (TODO confirm against the solver).
    if not math.isfinite(lam_tr_value) or lam_tr_value < 0:
        print("Invalid input. λ_tr must be a finite, non-negative number.")
        input_widget.setText(f"{current_value:.2E}")
        return
    problem.settings.scp.w_tr = lam_tr_value
    input_widget.setText(f"{lam_tr_value:.2E}")
def update_optimization_metrics(results, labels_dict):
    """Write one solver ``results`` dict into the metric labels.

    Does nothing when ``results`` is ``None`` (nothing published yet). Keys
    missing from ``results`` are shown as 0 (numbers) or ``"--"`` (status).
    λ_cost is read live from ``problem.settings.scp`` rather than from
    ``results``.

    Args:
        results: dict published by ``optimization_loop``, or ``None``.
        labels_dict: mapping of label name -> QLabel built by the control panel.
    """
    if results is None:
        return

    def show(label_name, text):
        labels_dict[label_name].setText(text)

    show("iter_label", f"Iteration: {results.get('iter', 0)}")
    show("j_tr_label", f"J_tr: {results.get('J_tr', 0.0):.2E}")
    show("j_vb_label", f"J_vb: {results.get('J_vb', 0.0):.2E}")
    show("j_vc_label", f"J_vc: {results.get('J_vc', 0.0):.2E}")
    show("objective_label", f"Objective: {results.get('cost', 0.0):.2E}")
    show("lam_cost_display_label", f"λ_cost: {problem.settings.scp.lam_cost:.2E}")
    # Timing and status are copied into results by optimization_loop from the
    # solver's emitted print data.
    show("dis_time_label", f"Dis Time: {results.get('dis_time', 0.0):.1f}ms")
    show("solve_time_label", f"Solve Time: {results.get('solve_time', 0.0):.1f}ms")
    show("status_label", f"Status: {results.get('prob_stat', '--')}")
# results key -> (key in the solver's emitted print data, fallback value).
_EMITTED_METRICS = {
    "dis_time": ("dis_time", 0.0),
    "solve_time": ("subprop_time", 0.0),
    "prob_stat": ("prob_stat", "--"),
    "cost": ("cost", 0.0),
}


def _latest_emitted_data():
    """Drain ``problem.print_queue`` and return its newest item (or ``None``).

    Assuming a FIFO queue (queue.Queue / multiprocessing.Queue), a single
    ``get_nowait()`` returns the *oldest* pending item, so the queue is drained
    and only the last item kept. Relying on ``queue.Empty`` instead of an
    ``empty()`` pre-check also handles another consumer taking an item between
    the check and the get.
    """
    print_queue = getattr(problem, "print_queue", None)
    if print_queue is None:
        return None
    latest = None
    while True:
        try:
            latest = print_queue.get_nowait()
        except queue.Empty:
            return latest


def optimization_loop():
    """Run SCP steps until asked to stop, publishing each iterate for the GUI.

    Each pass honours a pending reset request, takes one ``problem.step()``,
    packages the iterate and the latest solver timing/status into a dict,
    stores it in ``latest_results["results"]`` and sets ``new_result_event``.
    The loop checks ``running["stop"]`` before every step.
    """
    problem.initialize()
    try:
        while not running["stop"]:
            if reset_requested["reset"]:
                problem.reset()
                reset_requested["reset"] = False
                print("Problem reset to initial conditions")

            # One SCP iteration (automatically warm-starts from the previous one).
            step_result = problem.step()

            results = {
                "iter": step_result["scp_k"] - 1,  # displayed 0-indexed
                "J_tr": step_result["scp_J_tr"],
                "J_vb": step_result["scp_J_vb"],
                "J_vc": step_result["scp_J_vc"],
                "converged": step_result["converged"],
                "V_multi_shoot": problem.state.V_history[-1] if problem.state.V_history else [],
                "x": problem.state.x,  # current state trajectory
                "u": problem.state.u,  # current control trajectory
            }

            # Timing/status come from the solver's emitted data; fall back to
            # defaults when nothing (or a non-mapping) was emitted.
            emitted = _latest_emitted_data()
            if not isinstance(emitted, Mapping):
                emitted = {}
            for result_key, (emitted_key, default) in _EMITTED_METRICS.items():
                results[result_key] = emitted.get(emitted_key, default)

            results.update(plotting_dict)
            latest_results["results"] = results
            new_result_event.set()
    except KeyboardInterrupt:
        # Only reachable when this loop runs on the main thread: Python
        # delivers Ctrl+C (SIGINT) to the main thread, not to worker threads.
        running["stop"] = True
        print("Stopped by user.")
def _trajectory_points(results):
    """Return an (N, 3) array of positions to draw for one results dict.

    Prefers the per-node segment states in ``results["V_multi_shoot"]``: each
    column is a node's data, split into rows of length
    n_x + n_x*n_x + 2*n_x*n_u whose first three entries are taken as the
    position. Falls back to the first three columns of ``results["x"]`` when
    that data is missing or cannot be unpacked. Returns ``None`` when neither
    is available.
    """
    try:
        v_multi = np.asarray(results.get("V_multi_shoot", []))
        n_x = problem.settings.sim.n_states
        n_u = problem.settings.sim.n_controls
        row_len = n_x + n_x * n_x + 2 * n_x * n_u
        if v_multi.ndim == 2 and v_multi.size and v_multi.shape[0] % row_len == 0:
            # Transposing puts nodes first, so the reshape yields node 0's
            # segments, then node 1's, and so on.
            return v_multi.T.reshape(-1, row_len)[:, :3]
    except (TypeError, ValueError) as e:
        print(f"Plot update error: {e}")
    x_traj = results.get("x")
    if x_traj is None:
        return None
    return np.asarray(x_traj)[:, :3]


def plot_thread_func():
    """Create the GUI window and run the Qt event loop; blocks until it closes.

    Despite the name, this runs on the calling thread (the main thread in
    ``__main__``). A QTimer checks every 50 ms for a result newly published by
    ``optimization_loop`` and redraws only then. When the event loop exits,
    ``running["stop"]`` is set so the solver loop ends as well.
    """
    app = QApplication.instance()
    if app is None:
        app = QApplication([])

    print(f"Creating plot window... OpenGL available: {HAS_OPENGL}")
    plot_widget = Obstacle3DPlotWidget()
    plot_widget.setWindowTitle("3D Obstacle Avoidance Real-time Trajectory")
    plot_widget.resize(800, 600)
    plot_widget.show()
    print("Plot window created and shown")
    plot_widget.raise_()
    plot_widget.activateWindow()
    # Let Qt handle the pending show/raise events now; sleeping here would
    # block the thread without letting the window paint.
    app.processEvents()

    timer = QTimer()

    def update_plot():
        # Clear before reading: a result published after this point sets the
        # event again, so it is drawn on the next tick rather than lost.
        if not new_result_event.is_set():
            return
        new_result_event.clear()
        results = latest_results["results"]  # one snapshot for the whole update
        if results is None:
            return
        # PyQt5 aborts the application on an exception escaping a slot, so
        # report drawing errors instead of raising them.
        try:
            if HAS_OPENGL:
                points = _trajectory_points(results)
                if points is not None:
                    plot_widget.traj_scatter.setData(pos=points)
                for i, ellipsoid in enumerate(plot_widget.obs_ellipsoids):
                    center = problem.parameters[f"obstacle_center_{i + 1}"]
                    ellipsoid.resetTransform()
                    ellipsoid.translate(*center)
            # The metrics panel may not exist (e.g. a widget built without it).
            labels = getattr(plot_widget, "labels_dict", None)
            if labels is not None:
                update_optimization_metrics(results, labels)
        except Exception as e:
            print(f"Plot update error: {e}")

    timer.timeout.connect(update_plot)
    timer.start(50)  # poll every 50 ms

    print("Starting Qt event loop...")
    app.exec_()
    # Window closed: let optimization_loop's while-condition end the solver.
    running["stop"] = True
if __name__ == "__main__":
    # The solver runs on a daemon thread, so it never keeps the process alive
    # once the GUI has returned.
    opt_thread = threading.Thread(target=optimization_loop, daemon=True)
    opt_thread.start()
    # The GUI runs on the main thread; this call blocks in the Qt event loop
    # until the window is closed.
    plot_thread_func()