Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions controls/sae_2025_ws/src/uav/uav/ModeManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from rclpy.node import Node
from px4_msgs.msg import VehicleStatus
from std_srvs.srv import Trigger
from uav import VTOL, Multicopter
from uav import VTOL, Multicopter, Payload
from uav.autonomous_modes import Mode, LandingMode
from uav.utils import Vehicle
from uav.utils import VehicleType
import yaml

VISION_NODE_PATH = "uav.vision_nodes"
Expand All @@ -31,7 +31,7 @@ def __init__(self) -> None:
self.declare_parameter("camera_offsets", [0.0, 0.0, 0.0])
self.declare_parameter("debug", False)
self.declare_parameter("servo_only", False)
self.declare_parameter("vehicle_class", Vehicle.MULTICOPTER.name)
self.declare_parameter("vehicle_class", VehicleType.MULTICOPTER.name)

mode_map = self.get_parameter("mode_map").value
if not mode_map:
Expand Down Expand Up @@ -64,24 +64,26 @@ def __init__(self) -> None:
self.active_mode = None
self.last_update_time = time()
self.start_time = self.last_update_time
# Instantiate appropriate UAV subclass based on vehicle type
if vehicle_class == Vehicle.VTOL:
# Instantiate appropriate Vehicle subclass based on vehicle type
if vehicle_class == VehicleType.VTOL:
self.uav = VTOL(self, DEBUG=debug, camera_offsets=camera_offsets)
elif vehicle_class == VehicleType.PAYLOAD:
self.uav = Payload(self, DEBUG=debug)
else:
self.uav = Multicopter(self, DEBUG=debug, camera_offsets=camera_offsets)
self.get_logger().info("Mission Node has started!")
self.setup_vision(vision_nodes)
self.setup_modes(mode_map)
self.servo_only = servo_only

def _parse_vehicle_class(self, vehicle_class) -> Vehicle:
if isinstance(vehicle_class, Vehicle):
def _parse_vehicle_class(self, vehicle_class) -> VehicleType:
if isinstance(vehicle_class, VehicleType):
return vehicle_class
if isinstance(vehicle_class, str):
try:
return Vehicle[vehicle_class.upper()]
return VehicleType[vehicle_class.upper()]
except KeyError as exc:
valid = ", ".join(v.name for v in Vehicle)
valid = ", ".join(v.name for v in VehicleType)
raise ValueError(
f"Invalid vehicle_class '{vehicle_class}'. Expected one of: {valid}"
) from exc
Expand Down
9 changes: 4 additions & 5 deletions controls/sae_2025_ws/src/uav/uav/Multicopter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@ def __init__(
):
super().__init__(node, takeoff_amount, DEBUG, camera_offsets)

@property
def vehicle_type(self) -> str:
return "MULTICOPTER"

@property
def is_vtol(self) -> bool:
"""Multicopters are not VTOL."""
return False

@property
def vehicle_type(self) -> None:
"""Multicopters don't have vehicle_type (not VTOL)."""
return None

def vtol_transition_to(self, vtol_state, immediate=False):
"""Not available on multicopters."""
self.node.get_logger().warn(
Expand Down
65 changes: 65 additions & 0 deletions controls/sae_2025_ws/src/uav/uav/Payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from rclpy.node import Node
from payload_interfaces.msg import DriveCommand, MotorState
from payload_interfaces.srv import TimedDrive, ComputePidZieglerNichols
from uav.Vehicle import Vehicle


class Payload(Vehicle):
"""
Python ROS 2 client wrapper for the C++ payload node (ground car).
Communicates with the payload node via topics and services defined in payload_interfaces.
"""

def __init__(self, node: Node, payload_name: str = "payload", DEBUG: bool = False):
super().__init__(node, DEBUG)
self.payload_name = payload_name
self.motor_state: MotorState | None = None

self._drive_pub = node.create_publisher(
DriveCommand, f"/{payload_name}/cmd_drive", 10
)
self._timed_drive_client = node.create_client(
TimedDrive, f"/{payload_name}/timed_drive"
)
self._compute_pid_client = node.create_client(
ComputePidZieglerNichols, f"/{payload_name}/compute_pid"
)
self._motor_state_sub = node.create_subscription(
MotorState,
f"/{payload_name}/motor_state",
self._motor_state_callback,
10,
)

@property
def vehicle_type(self) -> str:
return "PAYLOAD"

def drive(self, linear: float, angular: float):
"""Publish a continuous drive command. linear: m/s, angular: rad/s (+left, -right)."""
msg = DriveCommand()
msg.linear = linear
msg.angular = angular
self._drive_pub.publish(msg)

def timed_drive(self, linear: float, angular: float, duration_sec: float):
"""Call the timed_drive service — drives for duration_sec then auto-stops. Returns a Future."""
req = TimedDrive.Request()
req.linear = linear
req.angular = angular
req.duration_sec = duration_sec
return self._timed_drive_client.call_async(req)

def stop(self):
"""Stop the ground vehicle."""
self.drive(0.0, 0.0)

def compute_pid(self, ku: float, pu: float):
"""Call the PID tuning service (Ziegler-Nichols). Returns a Future."""
req = ComputePidZieglerNichols.Request()
req.ku = ku
req.pu = pu
return self._compute_pid_client.call_async(req)

def _motor_state_callback(self, msg: MotorState):
self.motor_state = msg
188 changes: 7 additions & 181 deletions controls/sae_2025_ws/src/uav/uav/UAV.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from rclpy.node import Node
from uav.Vehicle import Vehicle
from px4_msgs.msg import (
OffboardControlMode,
TrajectorySetpoint,
Expand All @@ -18,9 +19,7 @@
QoSDurabilityPolicy,
)
import numpy as np
import math
from uav.px4_modes import PX4CustomMainMode, PX4CustomSubModeAuto
from uav.utils import R_earth

# Map nav_state value -> name for readable logging
_NAV_STATE_NAMES = {
Expand All @@ -35,7 +34,7 @@ def get_nav_state_str(val):
return _NAV_STATE_NAMES.get(val, str(val))


class UAV(ABC):
class UAV(Vehicle):
"""
Abstract base class for UAV control and interfacing with PX4 via ROS 2.
Subclasses: VTOL, Multicopter
Expand All @@ -44,21 +43,16 @@ class UAV(ABC):
def __init__(
self, node: Node, takeoff_amount=5.0, DEBUG=False, camera_offsets=[0, 0, 0]
):
self.node = node
self.DEBUG = DEBUG
super().__init__(node, DEBUG)
self.node.get_logger().info(f"Initializing UAV with DEBUG={DEBUG}")
self.vision_clients = {}

# Initialize necessary parameters to handle PX4 flight failures
# PX4/flight-specific state
self.flight_check = False
self.emergency_landing = False
self.failsafe = False
self.failsafe_px4 = False
self.failsafe_trigger = False
self.vehicle_status = None
self.vehicle_attitude = None
self.nav_state = None
self.arm_state = None

self.system_id = 1
self.component_id = 1
Expand All @@ -75,17 +69,12 @@ def __init__(
self.origin_set = False
self.roll = None
self.pitch = None
self.yaw = None
self.takeoff_amount = takeoff_amount
self.attempted_takeoff = False

# Initialize drone position
self.local_origin = None
self.gps_origin = None

# Store current drone position
self.global_position = None
self.local_position = None

# -------------------------
# Public commands
Expand Down Expand Up @@ -119,25 +108,6 @@ def set_origin(self):
self.camera_offsets = self.uav_to_local(self.camera_offsets)
self.origin_set = True

def distance_to_waypoint(self, coordinate_system, waypoint) -> float:
"""Calculate the distance to the current waypoint."""
if coordinate_system == "GPS":
curr_gps = self.get_gps()
return self.gps_distance_3d(
waypoint[0],
waypoint[1],
waypoint[2],
curr_gps[0],
curr_gps[1],
curr_gps[2],
)
elif coordinate_system == "LOCAL":
return np.sqrt(
(self.local_position.x - waypoint[0]) ** 2
+ (self.local_position.y - waypoint[1]) ** 2
+ (self.local_position.z - waypoint[2]) ** 2
)

def hover(self):
self._send_vehicle_command(
VehicleCommand.VEHICLE_CMD_DO_SET_MODE,
Expand Down Expand Up @@ -239,20 +209,6 @@ def publish_position_setpoint(self, coordinate, relative=False, lock_yaw=False):

self.trajectory_publisher.publish(msg)

def calculate_yaw(self, x: float, y: float) -> float:
"""Calculate the yaw angle to point towards the next waypoint."""
# Calculate relative position
dx = x - self.local_position.x
dy = y - self.local_position.y

# If very close to target (hovering), maintain current yaw to prevent spinning
# caused by noisy position estimates when dx/dy are near zero
if np.linalg.norm([dx, dy]) < 3.0 and self.yaw is not None:
return self.yaw

# Calculate yaw angle
yaw = np.arctan2(dy, dx)
return yaw

@abstractmethod
def _calculate_velocity(self, target_pos: tuple, lock_yaw: bool) -> list:
Expand Down Expand Up @@ -280,139 +236,9 @@ def publish_offboard_control_heartbeat_signal(self):
msg.timestamp = int(self.node.get_clock().now().nanoseconds / 1000)
self.offboard_mode_publisher.publish(msg)

def gps_distance_3d(self, lat1, lon1, alt1, lat2, lon2, alt2):
"""
Calculate the 3D distance in feet between two GPS points, including altitude.

Parameters:
lat1 (float): Latitude of the first point in decimal degrees.
lon1 (float): Longitude of the first point in decimal degrees.
alt1 (float): Altitude of the first point in feet above sea level.
lat2 (float): Latitude of the second point in decimal degrees.
lon2 (float): Longitude of the second point in decimal degrees.
alt2 (float): Altitude of the second point in feet above sea level.

Returns:
float: The 3D distance between the two points in feet.
"""
# Earth's radius in feet (using an average value)
curr_x, curr_y, curr_z = self.gps_to_local((lat1, lon1, alt1))
tar_x, tar_y, tar_z = self.gps_to_local((lat2, lon2, alt2))
return np.sqrt(
(curr_x - tar_x) ** 2 + (curr_y - tar_y) ** 2 + (curr_z - tar_z) ** 2
)

def gps_to_local(self, target):
"""
Convert target GPS coordinates to local NED coordinates.

Args:
target (tuple): (target_lat, target_lon, target_alt)
ref (tuple): (ref_lat, ref_lon, ref_alt) from the local position message

Returns:
tuple: (x, y, z) in the local frame where:
x is North (meters),
y is East (meters),
z is Down (meters)
"""
if self.gps_origin is None:
self.node.get_logger().error(
"gps_origin not set. Cannot convert GPS to local coordinates."
)
return None

target_lat, target_lon, target_alt = target
ref_lat, ref_lon, ref_alt = self.gps_origin

# Convert differences in latitude and longitude from degrees to radians
d_lat = math.radians(target_lat - ref_lat)
d_lon = math.radians(target_lon - ref_lon)

# Compute local displacements
x = d_lat * R_earth # North displacement
y = d_lon * R_earth * math.cos(math.radians(ref_lat)) # East displacement
z = -(target_alt - ref_alt) # Down displacement

return (x, y, z)

def uav_to_local(self, point, relative=False):
"""
Converts a point in the UAV's local frame to the global frame.

:param point: A tuple (point_x, point_y, point_z) in the UAV's local frame.
:param relative: If True, the point is relative to the current local position.
:return: A tuple (goal_x, goal_y, goal_z) representing the point in the global frame.
"""
current_pos = self.get_local_position()
point_x, point_y, point_z = point

# Rotate the x and y points according to the UAV's yaw angle.
rotated_point_x = point_x * math.cos(self.yaw) - point_y * math.sin(self.yaw)
rotated_point_y = point_x * math.sin(self.yaw) + point_y * math.cos(self.yaw)

# The z-point remains unchanged.
if relative:
return (
current_pos[0] + rotated_point_x,
current_pos[1] + rotated_point_y,
current_pos[2] + point_z,
)
else:
return (rotated_point_x, rotated_point_y, point_z)

def local_to_gps(self, local_pos):
"""
Convert a local NED coordinate to a GPS coordinate.

Args:
local_pos (tuple): (x, y, z) in meters, where:
x: North displacement,
y: East displacement,
z: Down displacement.
ref_gps (tuple): (lat, lon, alt) of the reference point (takeoff) in degrees and meters.

Returns:
tuple: (lat, lon, alt) GPS coordinate corresponding to local_pos.
"""
if self.gps_origin is None:
self.node.get_logger().error(
"gps_origin not set. Cannot convert local to GPS coordinates."
)
return None
else:
x, y, z = local_pos
lat0, lon0, alt0 = self.gps_origin

# Convert displacements from meters to degrees
dlat = (x / R_earth) * (180.0 / math.pi)
dlon = (y / (R_earth * math.cos(math.radians(lat0)))) * (180.0 / math.pi)

lat = lat0 + dlat
lon = lon0 + dlon
alt = alt0 - z # because z is down in NED
return (lat, lon, alt)

# -------------------------
# Getters / data access
# -------------------------
def get_gps(self):
if self.global_position:
return (
self.global_position.lat,
self.global_position.lon,
self.global_position.alt,
)
else:
self.node.get_logger().warn("No GPS data available.")
return None

def get_local_position(self):
if self.local_position:
return (self.local_position.x, self.local_position.y, self.local_position.z)
else:
self.node.get_logger().warn("No local position data available.")
return None
"""Alias for vehicle_to_local — kept for backwards compatibility."""
return self.vehicle_to_local(point, relative)

def _calculate_proportional_velocity(
self, direction: np.ndarray, distance: float
Expand Down
Loading
Loading