Spaces:
Runtime error
Runtime error
| """ | |
| Drone Navigation Environment (simplified — no trees). | |
| A quadrotor drone flies to a nearby target in an open arena. | |
| The RL policy commands velocity (forward/left/up/turn) while a built-in PD flight | |
| controller handles low-level motor mixing. | |
| """ | |
| import base64 | |
| import io | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from uuid import uuid4 | |
| # Configure MuJoCo rendering backend before importing mujoco | |
| if "MUJOCO_GL" not in os.environ and sys.platform != "darwin": | |
| os.environ.setdefault("MUJOCO_GL", "egl") | |
| import numpy as np | |
| try: | |
| from openenv.core.env_server.interfaces import Environment | |
| from ..models import DMControlAction, DMControlObservation, DMControlState | |
| except ImportError: | |
| from openenv.core.env_server.interfaces import Environment | |
| try: | |
| import sys as _sys | |
| from pathlib import Path as _Path | |
| _parent = str(_Path(__file__).parent.parent) | |
| if _parent not in _sys.path: | |
| _sys.path.insert(0, _parent) | |
| from models import DMControlAction, DMControlObservation, DMControlState | |
| except ImportError: | |
| try: | |
| from dm_control_env.models import ( | |
| DMControlAction, | |
| DMControlObservation, | |
| DMControlState, | |
| ) | |
| except ImportError: | |
| from envs.dm_control_env.models import ( | |
| DMControlAction, | |
| DMControlObservation, | |
| DMControlState, | |
| ) | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# --- Arena and episode limits ---------------------------------------------
ARENA_HALF = 10.0  # half side length in metres; the arena is 20x20 m
MAX_ALTITUDE = 8.0  # flying above this ends the episode (out of bounds)
MIN_ALTITUDE = 0.1  # dropping below this ends the episode (out of bounds)
TARGET_RADIUS = 0.5  # success if within this distance
TARGET_MIN_DIST = 2.0  # target at least this far from spawn
TARGET_MAX_DIST = 4.0  # target at most this far from spawn
MAX_STEPS = 1000  # control steps before the episode is cut off

# --- Timing ---------------------------------------------------------------
PHYSICS_DT = 0.002
CONTROL_DT = 0.02  # 50 Hz control

# --- Command scaling: a normalised action of +/-1 maps to these limits ------
MAX_XY_VEL = 3.0  # m/s
MAX_Z_VEL = 2.0  # m/s
MAX_YAW_RATE = 2.0  # rad/s

# --- Flight-controller PD gains -------------------------------------------
KP_VEL = 4.0
KD_VEL = 1.5
KP_ATT = 8.0
KD_ATT = 2.0

# --- Drone physical parameters --------------------------------------------
DRONE_MASS = 0.48  # total mass (body 0.4 + arms 0.08) close to XML
GRAVITY = 9.81
HOVER_THRUST = DRONE_MASS * GRAVITY / 4.0  # per-motor hover
ARM_LENGTH = 0.14  # distance from CoM to rotor

# MuJoCo scene description, expected to sit next to this module.
XML_PATH = str(Path(__file__).with_name("drone_forest.xml"))
class DroneForestEnvironment(Environment):
    """Drone navigates to a nearby target in an open arena.

    The agent sends a 4-vector ``[vx, vy, vz, yaw_rate]`` with components in
    ``[-1, 1]``. A built-in PD flight controller converts it into four motor
    thrusts, and the MuJoCo model loaded from ``XML_PATH`` is advanced by one
    control step.

    Reward per step is ``+0.1`` when the drone finished the step closer to
    the target than it started, plus a terminal bonus:

    * ``+100`` - within ``TARGET_RADIUS`` of the target,
    * ``-50``  - the ``drone_body`` geom is in contact with the ``ground`` geom,
    * ``-10``  - outside the arena or the allowed altitude band,
    * ``0``    - ``MAX_STEPS`` control steps elapsed.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        render_height: Optional[int] = None,
        render_width: Optional[int] = None,
        **kwargs,
    ):
        """Create the environment; the MuJoCo model is loaded lazily.

        Args:
            render_height: Rendered frame height in pixels. When falsy, the
                ``DMCONTROL_RENDER_HEIGHT`` environment variable is used,
                defaulting to 512.
            render_width: Rendered frame width in pixels. When falsy, the
                ``DMCONTROL_RENDER_WIDTH`` environment variable is used,
                defaulting to 512.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self._model = None
        self._data = None
        self._render_height = render_height or int(
            os.environ.get("DMCONTROL_RENDER_HEIGHT", "512")
        )
        self._render_width = render_width or int(
            os.environ.get("DMCONTROL_RENDER_WIDTH", "512")
        )
        self._include_pixels = False
        self._step_count = 0
        self._prev_dist = None
        self._target_pos = np.zeros(3)
        self._done = False
        self._rng = np.random.RandomState()
        self._state = DMControlState(
            episode_id=str(uuid4()),
            step_count=0,
            domain_name="drone_forest",
            task_name="navigate",
        )

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------
    def _ensure_model(self):
        """Load the MuJoCo model and cache ids/specs; no-op if already loaded."""
        if self._model is not None:
            return

        import mujoco

        # Build both objects before publishing either, so a failure while
        # creating MjData does not leave a model without matching data.
        model = mujoco.MjModel.from_xml_path(XML_PATH)
        data = mujoco.MjData(model)
        self._model, self._data = model, data

        # Precompute body / geom ids used on every step.
        self._drone_body_id = mujoco.mj_name2id(
            model, mujoco.mjtObj.mjOBJ_BODY, "drone"
        )
        self._target_body_id = mujoco.mj_name2id(
            model, mujoco.mjtObj.mjOBJ_BODY, "target"
        )
        self._drone_body_geom_id = mujoco.mj_name2id(
            model, mujoco.mjtObj.mjOBJ_GEOM, "drone_body"
        )
        self._ground_geom_id = mujoco.mj_name2id(
            model, mujoco.mjtObj.mjOBJ_GEOM, "ground"
        )

        # Advertise action / observation layout through the state object.
        self._state.action_spec = {
            "shape": [4],
            "dtype": "float64",
            "minimum": [-1.0, -1.0, -1.0, -1.0],
            "maximum": [1.0, 1.0, 1.0, 1.0],
            "name": "velocity_command",
        }
        self._state.observation_spec = {
            "position": {"shape": [3], "dtype": "float64"},
            "velocity": {"shape": [3], "dtype": "float64"},
            "orientation": {"shape": [3], "dtype": "float64"},
            "angular_velocity": {"shape": [3], "dtype": "float64"},
            "target_relative": {"shape": [3], "dtype": "float64"},
        }
        self._state.physics_timestep = PHYSICS_DT
        self._state.control_timestep = CONTROL_DT

    # ------------------------------------------------------------------
    # Target placement
    # ------------------------------------------------------------------
    def _place_target(self):
        """Place the target at a random spot near the spawn point.

        The horizontal offset from the world origin is drawn uniformly in
        ``[TARGET_MIN_DIST, TARGET_MAX_DIST]`` at a uniform random bearing;
        the height is drawn uniformly in ``[1.0, 2.5]``. The ``target`` body
        position in the model is updated to match.
        """
        import mujoco

        angle = self._rng.uniform(0, 2 * np.pi)
        dist = self._rng.uniform(TARGET_MIN_DIST, TARGET_MAX_DIST)
        tx = dist * np.cos(angle)
        ty = dist * np.sin(angle)
        tz = self._rng.uniform(1.0, 2.5)
        self._target_pos = np.array([tx, ty, tz])
        self._model.body_pos[self._target_body_id] = self._target_pos.copy()
        # Recompute derived quantities after changing body positions.
        mujoco.mj_forward(self._model, self._data)

    # ------------------------------------------------------------------
    # Flight controller
    # ------------------------------------------------------------------
    def _flight_controller(self, cmd: np.ndarray) -> np.ndarray:
        """Convert a normalised velocity command into four motor thrusts.

        Args:
            cmd: ``[vx, vy, vz, yaw_rate]`` with each component in ``[-1, 1]``.
                ``vx``/``vy`` are rotated by the current yaw before being
                compared with the velocity read from ``qvel``.

        Returns:
            Array ``[FR, FL, BR, BL]`` of motor commands clipped to
            ``[0, 3]``.
        """
        # Scale normalised commands to physical units.
        vx_cmd = cmd[0] * MAX_XY_VEL
        vy_cmd = cmd[1] * MAX_XY_VEL
        vz_cmd = cmd[2] * MAX_Z_VEL
        yaw_rate_cmd = cmd[3] * MAX_YAW_RATE

        # Current state, read from the first 7 qpos / 6 qvel entries.
        quat = self._data.qpos[3:7].copy()  # w, x, y, z
        vel = self._data.qvel[:3].copy()
        ang_vel = self._data.qvel[3:6].copy()
        roll, pitch, yaw = self._quat_to_euler(quat)

        # Rotate the commanded XY velocity by yaw (heading frame -> world).
        cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
        vx_world = vx_cmd * cos_yaw - vy_cmd * sin_yaw
        vy_world = vx_cmd * sin_yaw + vy_cmd * cos_yaw

        # Velocity error in the world frame.
        vx_err = vx_world - vel[0]
        vy_err = vy_world - vel[1]
        vz_err = vz_cmd - vel[2]

        # Desired roll/pitch from XY velocity error (small angle approx),
        # limited to +/-0.5 rad.
        # NOTE(review): the tilt set-points use world-frame errors without
        # rotating them back by yaw; confirm this is intended for yaw != 0.
        # NOTE(review): only the proportional gain KP_VEL is applied here;
        # KD_VEL is defined at module level but not used.
        desired_pitch = np.clip(KP_VEL * vx_err, -0.5, 0.5)
        desired_roll = np.clip(-KP_VEL * vy_err, -0.5, 0.5)

        # Attitude PD: proportional on angle error, damping on body rates.
        roll_err = desired_roll - roll
        pitch_err = desired_pitch - pitch
        yaw_rate_err = yaw_rate_cmd - ang_vel[2]
        torque_roll = KP_ATT * roll_err - KD_ATT * ang_vel[0]
        torque_pitch = KP_ATT * pitch_err - KD_ATT * ang_vel[1]
        torque_yaw = KP_ATT * yaw_rate_err

        # Collective thrust: hover + vertical velocity correction.
        thrust = DRONE_MASS * GRAVITY + KP_VEL * vz_err * DRONE_MASS

        # Quadrotor mixer: convert thrust + torques to 4 motor thrusts.
        # Layout: FR(+x,-y), FL(+x,+y), BR(-x,-y), BL(-x,+y)
        base = thrust / 4.0
        pitch_term = torque_pitch / (4.0 * ARM_LENGTH)
        roll_term = torque_roll / (4.0 * ARM_LENGTH)
        yaw_term = torque_yaw / 4.0
        t_fr = base + pitch_term - roll_term - yaw_term
        t_fl = base + pitch_term + roll_term + yaw_term
        t_br = base - pitch_term - roll_term + yaw_term
        t_bl = base - pitch_term + roll_term - yaw_term

        # Clamp to actuator range [0, 3].
        return np.clip([t_fr, t_fl, t_br, t_bl], 0.0, 3.0)

    @staticmethod
    def _quat_to_euler(quat: np.ndarray):
        """Convert quaternion [w, x, y, z] to Euler angles (roll, pitch, yaw).

        Declared static because it takes no ``self`` yet is invoked as
        ``self._quat_to_euler(quat)``; without the decorator that call would
        pass the instance as ``quat`` and raise ``TypeError``.

        Returns:
            Tuple ``(roll, pitch, yaw)`` in radians. The pitch sine is
            clipped to ``[-1, 1]`` so rounding error cannot produce NaN.
        """
        w, x, y, z = quat

        # Roll (x-axis rotation)
        sinr = 2.0 * (w * x + y * z)
        cosr = 1.0 - 2.0 * (x * x + y * y)
        roll = np.arctan2(sinr, cosr)

        # Pitch (y-axis rotation)
        sinp = np.clip(2.0 * (w * y - z * x), -1.0, 1.0)
        pitch = np.arcsin(sinp)

        # Yaw (z-axis rotation)
        siny = 2.0 * (w * z + x * y)
        cosy = 1.0 - 2.0 * (y * y + z * z)
        yaw = np.arctan2(siny, cosy)

        return roll, pitch, yaw

    # ------------------------------------------------------------------
    # Observations
    # ------------------------------------------------------------------
    def _get_obs(self) -> Dict[str, List[float]]:
        """Build the observation dict from the current simulation state.

        Returns:
            Mapping with keys ``position``, ``velocity``, ``orientation``
            (roll, pitch, yaw), ``angular_velocity`` and ``target_relative``
            (target position minus drone position), each a list of 3 floats.
        """
        pos = self._data.qpos[:3].copy()
        vel = self._data.qvel[:3].copy()
        quat = self._data.qpos[3:7].copy()
        ang_vel = self._data.qvel[3:6].copy()
        roll, pitch, yaw = self._quat_to_euler(quat)
        target_rel = self._target_pos - pos
        return {
            "position": pos.tolist(),
            "velocity": vel.tolist(),
            "orientation": [float(roll), float(pitch), float(yaw)],
            "angular_velocity": ang_vel.tolist(),
            "target_relative": target_rel.tolist(),
        }

    # ------------------------------------------------------------------
    # Collision detection
    # ------------------------------------------------------------------
    def _check_collisions(self) -> bool:
        """Return True if the drone body geom is in contact with the ground."""
        for i in range(self._data.ncon):
            contact = self._data.contact[i]
            geoms = (contact.geom1, contact.geom2)
            if (
                self._drone_body_geom_id in geoms
                and self._ground_geom_id in geoms
            ):
                return True
        return False

    # ------------------------------------------------------------------
    # Reward
    # ------------------------------------------------------------------
    def _compute_reward(self, pos: np.ndarray) -> float:
        """Return the shaping reward and remember the new distance.

        Args:
            pos: Current drone position.

        Returns:
            ``0.1`` if the drone is strictly closer to the target than at the
            previous call, otherwise ``0.0``.
        """
        dist = np.linalg.norm(self._target_pos - pos)
        moved_closer = self._prev_dist is not None and dist < self._prev_dist
        self._prev_dist = dist
        return 0.1 if moved_closer else 0.0

    # ------------------------------------------------------------------
    # Termination
    # ------------------------------------------------------------------
    def _check_termination(self, pos: np.ndarray):
        """Decide whether the episode ends and with which bonus.

        Conditions are checked in priority order: success, ground collision,
        out of bounds, step limit.

        Returns:
            Tuple ``(done, bonus_reward)``.
        """
        dist = np.linalg.norm(self._target_pos - pos)

        # Success
        if dist < TARGET_RADIUS:
            return True, 100.0

        # Collision
        if self._check_collisions():
            return True, -50.0

        # Out of bounds
        outside_xy = abs(pos[0]) > ARENA_HALF or abs(pos[1]) > ARENA_HALF
        outside_z = pos[2] > MAX_ALTITUDE or pos[2] < MIN_ALTITUDE
        if outside_xy or outside_z:
            return True, -10.0

        # Max steps
        if self._step_count >= MAX_STEPS:
            return True, 0.0

        return False, 0.0

    # ------------------------------------------------------------------
    # Core interface
    # ------------------------------------------------------------------
    def reset(
        self,
        domain_name: Optional[str] = None,
        task_name: Optional[str] = None,
        seed: Optional[int] = None,
        render: bool = False,
        **kwargs,
    ) -> DMControlObservation:
        """Start a new episode.

        Args:
            domain_name: Accepted for interface compatibility; not used.
            task_name: Accepted for interface compatibility; not used.
            seed: When given, re-seeds the RNG used for target placement.
            render: Include a rendered frame in this observation and in all
                subsequent ``step`` observations of the episode.
            **kwargs: Accepted and ignored.

        Returns:
            Initial observation with ``reward=0.0`` and ``done=False``.
        """
        import mujoco

        self._ensure_model()
        self._include_pixels = render
        if seed is not None:
            self._rng = np.random.RandomState(seed)

        # Reset data to defaults, then place the target nearby.
        mujoco.mj_resetData(self._model, self._data)
        self._place_target()

        # Place drone at origin, altitude 1.5, level and at rest.
        self._data.qpos[:3] = [0.0, 0.0, 1.5]
        self._data.qpos[3:7] = [1.0, 0.0, 0.0, 0.0]  # identity quaternion
        self._data.qvel[:] = 0.0
        mujoco.mj_forward(self._model, self._data)

        self._step_count = 0
        pos = self._data.qpos[:3].copy()
        self._prev_dist = float(np.linalg.norm(self._target_pos - pos))
        self._done = False

        # Fresh episode id; specs are carried over from the previous state.
        self._state = DMControlState(
            episode_id=str(uuid4()),
            step_count=0,
            domain_name="drone_forest",
            task_name="navigate",
            action_spec=self._state.action_spec,
            observation_spec=self._state.observation_spec,
            physics_timestep=PHYSICS_DT,
            control_timestep=CONTROL_DT,
        )

        obs = self._get_obs()
        pixels = self._render_pixels() if render else None
        return DMControlObservation(
            observations=obs,
            pixels=pixels,
            reward=0.0,
            done=False,
        )

    def step(
        self,
        action: DMControlAction,
        render: bool = False,
        **kwargs,
    ) -> DMControlObservation:
        """Apply one velocity command and advance one control step.

        Args:
            action: Action whose first four ``values`` are
                ``[vx, vy, vz, yaw_rate]``; they are clipped to ``[-1, 1]``.
            render: Include a rendered frame in the returned observation
                (also done when ``reset`` was called with ``render=True``).
            **kwargs: Accepted and ignored.

        Returns:
            Observation carrying the step reward (shaping + terminal bonus)
            and the ``done`` flag.

        Raises:
            RuntimeError: If ``reset`` has not been called, or the episode
                has already finished.
            ValueError: If the action does not supply four finite values.
        """
        import mujoco

        if self._model is None or self._data is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        # Clip action to [-1, 1]; reject malformed input before it can reach
        # the controller (short vectors) or the physics state (NaN / inf).
        cmd = np.clip(np.array(action.values[:4], dtype=np.float64), -1.0, 1.0)
        if cmd.shape != (4,) or not np.all(np.isfinite(cmd)):
            raise ValueError(
                "Action must provide 4 finite values [vx, vy, vz, yaw_rate]; "
                f"got {action.values!r}"
            )

        # Run flight controller and hold its output for the whole control step.
        self._data.ctrl[:4] = self._flight_controller(cmd)

        # Step physics for one control timestep (multiple physics substeps).
        # round() guards against a float quotient landing just under an
        # integer, which int() alone would truncate.
        n_substeps = max(1, int(round(CONTROL_DT / PHYSICS_DT)))
        for _ in range(n_substeps):
            mujoco.mj_step(self._model, self._data)

        self._step_count += 1
        self._state.step_count = self._step_count

        pos = self._data.qpos[:3].copy()

        # Compute reward and check termination.
        reward = self._compute_reward(pos)
        done, bonus = self._check_termination(pos)
        reward += bonus
        self._done = done

        obs = self._get_obs()
        pixels = self._render_pixels() if (render or self._include_pixels) else None
        return DMControlObservation(
            observations=obs,
            pixels=pixels,
            reward=float(reward),
            done=done,
        )

    async def reset_async(self, **kwargs) -> DMControlObservation:
        """Async ``reset``: inline on macOS, in a worker thread elsewhere."""
        if sys.platform == "darwin":
            return self.reset(**kwargs)

        import asyncio

        return await asyncio.to_thread(self.reset, **kwargs)

    async def step_async(
        self, action: DMControlAction, render: bool = False, **kwargs
    ) -> DMControlObservation:
        """Async ``step``: inline on macOS, in a worker thread elsewhere."""
        if sys.platform == "darwin":
            return self.step(action, render=render, **kwargs)

        import asyncio

        return await asyncio.to_thread(self.step, action, render=render, **kwargs)

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------
    def _render_pixels(self) -> Optional[str]:
        """Render the ``tracking`` camera to a base64-encoded PNG.

        Rendering is best-effort: any failure is printed with its traceback
        and ``None`` is returned instead of raising.

        Returns:
            Base64 text of the PNG image, or ``None`` on failure.
        """
        try:
            import mujoco
            from PIL import Image

            renderer = mujoco.Renderer(
                self._model, height=self._render_height, width=self._render_width
            )
            try:
                renderer.update_scene(self._data, camera="tracking")
                frame = renderer.render()
            finally:
                # Always release the renderer, even when rendering fails.
                renderer.close()

            buf = io.BytesIO()
            Image.fromarray(frame).save(buf, format="PNG")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
        except Exception as e:
            import traceback

            print(f"[render error] {e}")
            traceback.print_exc()
            return None

    @property
    def state(self) -> DMControlState:
        """Current episode state (episode id, step count, specs).

        Exposed as a property so that ``env.state`` yields the state object
        itself rather than a bound method.
        """
        return self._state

    def close(self) -> None:
        """Drop the model and data; a later ``reset`` reloads them."""
        self._model = None
        self._data = None

    def __del__(self):
        # Best-effort cleanup; never let errors escape during finalisation.
        try:
            self.close()
        except Exception:
            pass