Leeps committed on
Commit
dee27a2
·
1 Parent(s): 9b47d4a

Add drone + forest sim

Browse files
Dockerfile CHANGED
@@ -71,4 +71,5 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
71
  # Run the FastAPI server
72
  # Use exec to replace the shell with uvicorn so it receives SIGINT/SIGTERM directly
73
  ENV ENABLE_WEB_INTERFACE=true
 
74
  CMD ["sh", "-c", "cd /app/env && exec uvicorn server.app:app --host 0.0.0.0 --port 8000"]
 
71
  # Run the FastAPI server
72
  # Use exec to replace the shell with uvicorn so it receives SIGINT/SIGTERM directly
73
  ENV ENABLE_WEB_INTERFACE=true
74
+ ENV OPENENV_ENVIRONMENT="drone_forest"
75
  CMD ["sh", "-c", "cd /app/env && exec uvicorn server.app:app --host 0.0.0.0 --port 8000"]
server/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  """dm_control OpenEnv server module."""
2
 
3
  from .dm_control_environment import DMControlEnvironment
 
4
 
5
- __all__ = ["DMControlEnvironment"]
 
1
  """dm_control OpenEnv server module."""
2
 
3
  from .dm_control_environment import DMControlEnvironment
4
+ from .drone_forest_environment import DroneForestEnvironment
5
 
6
+ __all__ = ["DMControlEnvironment", "DroneForestEnvironment"]
server/app.py CHANGED
@@ -21,11 +21,14 @@ Usage:
21
  uv run --project . server
22
  """
23
 
 
 
24
  try:
25
  from openenv.core.env_server.http_server import create_app
26
 
27
  from ..models import DMControlAction, DMControlObservation
28
  from .dm_control_environment import DMControlEnvironment
 
29
  except ImportError:
30
  from openenv.core.env_server.http_server import create_app
31
 
@@ -38,25 +41,40 @@ except ImportError:
38
  sys.path.insert(0, _parent)
39
  from models import DMControlAction, DMControlObservation
40
  from server.dm_control_environment import DMControlEnvironment
 
41
  except ImportError:
42
  try:
43
  from dm_control_env.models import DMControlAction, DMControlObservation
44
  from dm_control_env.server.dm_control_environment import (
45
  DMControlEnvironment,
46
  )
 
 
 
47
  except ImportError:
48
  from envs.dm_control_env.models import DMControlAction, DMControlObservation
49
  from envs.dm_control_env.server.dm_control_environment import (
50
  DMControlEnvironment,
51
  )
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  # Create the app with web interface
54
  # Pass the class (factory) for concurrent session support
55
  app = create_app(
56
- DMControlEnvironment,
57
  DMControlAction,
58
  DMControlObservation,
59
- env_name="dm_control_env",
60
  )
61
 
62
 
 
21
  uv run --project . server
22
  """
23
 
24
+ import os
25
+
26
  try:
27
  from openenv.core.env_server.http_server import create_app
28
 
29
  from ..models import DMControlAction, DMControlObservation
30
  from .dm_control_environment import DMControlEnvironment
31
+ from .drone_forest_environment import DroneForestEnvironment
32
  except ImportError:
33
  from openenv.core.env_server.http_server import create_app
34
 
 
41
  sys.path.insert(0, _parent)
42
  from models import DMControlAction, DMControlObservation
43
  from server.dm_control_environment import DMControlEnvironment
44
+ from server.drone_forest_environment import DroneForestEnvironment
45
  except ImportError:
46
  try:
47
  from dm_control_env.models import DMControlAction, DMControlObservation
48
  from dm_control_env.server.dm_control_environment import (
49
  DMControlEnvironment,
50
  )
51
+ from dm_control_env.server.drone_forest_environment import (
52
+ DroneForestEnvironment,
53
+ )
54
  except ImportError:
55
  from envs.dm_control_env.models import DMControlAction, DMControlObservation
56
  from envs.dm_control_env.server.dm_control_environment import (
57
  DMControlEnvironment,
58
  )
59
+ from envs.dm_control_env.server.drone_forest_environment import (
60
+ DroneForestEnvironment,
61
+ )
62
+
63
+ # Select environment based on OPENENV_ENVIRONMENT env var
64
+ if os.environ.get("OPENENV_ENVIRONMENT") == "drone_forest":
65
+ _env_cls = DroneForestEnvironment
66
+ _env_name = "drone_forest"
67
+ else:
68
+ _env_cls = DMControlEnvironment
69
+ _env_name = "dm_control_env"
70
 
71
  # Create the app with web interface
72
  # Pass the class (factory) for concurrent session support
73
  app = create_app(
74
+ _env_cls,
75
  DMControlAction,
76
  DMControlObservation,
77
+ env_name=_env_name,
78
  )
79
 
80
 
server/drone_forest.xml ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <mujoco model="drone_forest">
2
+ <compiler angle="radian"/>
3
+
4
+ <option timestep="0.002" gravity="0 0 -9.81" integrator="RK4"/>
5
+
6
+ <default>
7
+ <geom contype="1" conaffinity="1" condim="3" friction="1 0.5 0.5"/>
8
+ <joint damping="0.01"/>
9
+ </default>
10
+
11
+ <asset>
12
+ <texture name="grid" type="2d" builtin="checker" rgb1="0.3 0.6 0.3" rgb2="0.2 0.5 0.2"
13
+ width="512" height="512"/>
14
+ <material name="ground_mat" texture="grid" texrepeat="10 10" texuniform="true"
15
+ reflectance="0.1"/>
16
+ <material name="drone_body_mat" rgba="0.2 0.2 0.2 1"/>
17
+ <material name="drone_arm_mat" rgba="0.4 0.4 0.4 1"/>
18
+ <material name="rotor_mat" rgba="0.1 0.1 0.8 1"/>
19
+ <material name="trunk_mat" rgba="0.45 0.3 0.15 1"/>
20
+ <material name="canopy_mat" rgba="0.1 0.5 0.1 0.8"/>
21
+ <material name="target_mat" rgba="1.0 0.0 0.0 0.6"/>
22
+ </asset>
23
+
24
+ <worldbody>
25
+ <!-- Ground plane -->
26
+ <geom name="ground" type="plane" size="25 25 0.1" material="ground_mat"
27
+ contype="1" conaffinity="1"/>
28
+
29
+ <!-- Drone -->
30
+ <body name="drone" pos="0 0 1.5">
31
+ <freejoint name="drone_joint"/>
32
+ <inertial pos="0 0 0" mass="0.4" diaginertia="0.003 0.003 0.005"/>
33
+
34
+ <!-- Central body -->
35
+ <geom name="drone_body" type="box" size="0.06 0.06 0.025" material="drone_body_mat"
36
+ mass="0.2" contype="1" conaffinity="1"/>
37
+
38
+ <!-- Arm +X -Y (front-right) -->
39
+ <geom name="arm_fr" type="capsule" fromto="0.04 -0.04 0 0.14 -0.14 0"
40
+ size="0.01" material="drone_arm_mat" mass="0.02" contype="0" conaffinity="0"/>
41
+ <site name="rotor_fr" pos="0.14 -0.14 0.01" size="0.04 0.005" type="cylinder"
42
+ rgba="0.1 0.1 0.8 0.5"/>
43
+
44
+ <!-- Arm +X +Y (front-left) -->
45
+ <geom name="arm_fl" type="capsule" fromto="0.04 0.04 0 0.14 0.14 0"
46
+ size="0.01" material="drone_arm_mat" mass="0.02" contype="0" conaffinity="0"/>
47
+ <site name="rotor_fl" pos="0.14 0.14 0.01" size="0.04 0.005" type="cylinder"
48
+ rgba="0.1 0.1 0.8 0.5"/>
49
+
50
+ <!-- Arm -X -Y (back-right) -->
51
+ <geom name="arm_br" type="capsule" fromto="-0.04 -0.04 0 -0.14 -0.14 0"
52
+ size="0.01" material="drone_arm_mat" mass="0.02" contype="0" conaffinity="0"/>
53
+ <site name="rotor_br" pos="-0.14 -0.14 0.01" size="0.04 0.005" type="cylinder"
54
+ rgba="0.1 0.1 0.8 0.5"/>
55
+
56
+ <!-- Arm -X +Y (back-left) -->
57
+ <geom name="arm_bl" type="capsule" fromto="-0.04 0.04 0 -0.14 0.14 0"
58
+ size="0.01" material="drone_arm_mat" mass="0.02" contype="0" conaffinity="0"/>
59
+ <site name="rotor_bl" pos="-0.14 0.14 0.01" size="0.04 0.005" type="cylinder"
60
+ rgba="0.1 0.1 0.8 0.5"/>
61
+
62
+ <!-- FPV camera (front-facing, slightly down) -->
63
+ <camera name="drone_fpv" pos="0.08 0 -0.01" xyaxes="0 -1 0 0.1 0 1"
64
+ fovy="90"/>
65
+ </body>
66
+
67
+ <!-- Target (visual only) -->
68
+ <body name="target" pos="5 5 1.5">
69
+ <geom name="target_geom" type="sphere" size="0.3" material="target_mat"
70
+ contype="0" conaffinity="0"/>
71
+ </body>
72
+
73
+ <!-- 25 tree bodies - positions set at runtime via Python -->
74
+ <body name="tree_0" pos="3 2 0">
75
+ <geom name="trunk_0" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
76
+ material="trunk_mat" contype="1" conaffinity="1"/>
77
+ <geom name="canopy_0" type="sphere" size="0.6" pos="0 0 3.2"
78
+ material="canopy_mat" contype="0" conaffinity="0"/>
79
+ </body>
80
+ <body name="tree_1" pos="5 -3 0">
81
+ <geom name="trunk_1" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
82
+ material="trunk_mat" contype="1" conaffinity="1"/>
83
+ <geom name="canopy_1" type="sphere" size="0.6" pos="0 0 3.2"
84
+ material="canopy_mat" contype="0" conaffinity="0"/>
85
+ </body>
86
+ <body name="tree_2" pos="-4 4 0">
87
+ <geom name="trunk_2" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
88
+ material="trunk_mat" contype="1" conaffinity="1"/>
89
+ <geom name="canopy_2" type="sphere" size="0.6" pos="0 0 3.2"
90
+ material="canopy_mat" contype="0" conaffinity="0"/>
91
+ </body>
92
+ <body name="tree_3" pos="7 1 0">
93
+ <geom name="trunk_3" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
94
+ material="trunk_mat" contype="1" conaffinity="1"/>
95
+ <geom name="canopy_3" type="sphere" size="0.6" pos="0 0 3.2"
96
+ material="canopy_mat" contype="0" conaffinity="0"/>
97
+ </body>
98
+ <body name="tree_4" pos="-2 -5 0">
99
+ <geom name="trunk_4" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
100
+ material="trunk_mat" contype="1" conaffinity="1"/>
101
+ <geom name="canopy_4" type="sphere" size="0.6" pos="0 0 3.2"
102
+ material="canopy_mat" contype="0" conaffinity="0"/>
103
+ </body>
104
+ <body name="tree_5" pos="1 6 0">
105
+ <geom name="trunk_5" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
106
+ material="trunk_mat" contype="1" conaffinity="1"/>
107
+ <geom name="canopy_5" type="sphere" size="0.6" pos="0 0 3.2"
108
+ material="canopy_mat" contype="0" conaffinity="0"/>
109
+ </body>
110
+ <body name="tree_6" pos="-6 -1 0">
111
+ <geom name="trunk_6" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
112
+ material="trunk_mat" contype="1" conaffinity="1"/>
113
+ <geom name="canopy_6" type="sphere" size="0.6" pos="0 0 3.2"
114
+ material="canopy_mat" contype="0" conaffinity="0"/>
115
+ </body>
116
+ <body name="tree_7" pos="4 -6 0">
117
+ <geom name="trunk_7" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
118
+ material="trunk_mat" contype="1" conaffinity="1"/>
119
+ <geom name="canopy_7" type="sphere" size="0.6" pos="0 0 3.2"
120
+ material="canopy_mat" contype="0" conaffinity="0"/>
121
+ </body>
122
+ <body name="tree_8" pos="-3 3 0">
123
+ <geom name="trunk_8" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
124
+ material="trunk_mat" contype="1" conaffinity="1"/>
125
+ <geom name="canopy_8" type="sphere" size="0.6" pos="0 0 3.2"
126
+ material="canopy_mat" contype="0" conaffinity="0"/>
127
+ </body>
128
+ <body name="tree_9" pos="6 4 0">
129
+ <geom name="trunk_9" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
130
+ material="trunk_mat" contype="1" conaffinity="1"/>
131
+ <geom name="canopy_9" type="sphere" size="0.6" pos="0 0 3.2"
132
+ material="canopy_mat" contype="0" conaffinity="0"/>
133
+ </body>
134
+ <body name="tree_10" pos="-5 -4 0">
135
+ <geom name="trunk_10" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
136
+ material="trunk_mat" contype="1" conaffinity="1"/>
137
+ <geom name="canopy_10" type="sphere" size="0.6" pos="0 0 3.2"
138
+ material="canopy_mat" contype="0" conaffinity="0"/>
139
+ </body>
140
+ <body name="tree_11" pos="2 -2 0">
141
+ <geom name="trunk_11" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
142
+ material="trunk_mat" contype="1" conaffinity="1"/>
143
+ <geom name="canopy_11" type="sphere" size="0.6" pos="0 0 3.2"
144
+ material="canopy_mat" contype="0" conaffinity="0"/>
145
+ </body>
146
+ <body name="tree_12" pos="-1 7 0">
147
+ <geom name="trunk_12" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
148
+ material="trunk_mat" contype="1" conaffinity="1"/>
149
+ <geom name="canopy_12" type="sphere" size="0.6" pos="0 0 3.2"
150
+ material="canopy_mat" contype="0" conaffinity="0"/>
151
+ </body>
152
+ <body name="tree_13" pos="8 -2 0">
153
+ <geom name="trunk_13" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
154
+ material="trunk_mat" contype="1" conaffinity="1"/>
155
+ <geom name="canopy_13" type="sphere" size="0.6" pos="0 0 3.2"
156
+ material="canopy_mat" contype="0" conaffinity="0"/>
157
+ </body>
158
+ <body name="tree_14" pos="-7 5 0">
159
+ <geom name="trunk_14" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
160
+ material="trunk_mat" contype="1" conaffinity="1"/>
161
+ <geom name="canopy_14" type="sphere" size="0.6" pos="0 0 3.2"
162
+ material="canopy_mat" contype="0" conaffinity="0"/>
163
+ </body>
164
+ <body name="tree_15" pos="3 7 0">
165
+ <geom name="trunk_15" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
166
+ material="trunk_mat" contype="1" conaffinity="1"/>
167
+ <geom name="canopy_15" type="sphere" size="0.6" pos="0 0 3.2"
168
+ material="canopy_mat" contype="0" conaffinity="0"/>
169
+ </body>
170
+ <body name="tree_16" pos="-4 -7 0">
171
+ <geom name="trunk_16" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
172
+ material="trunk_mat" contype="1" conaffinity="1"/>
173
+ <geom name="canopy_16" type="sphere" size="0.6" pos="0 0 3.2"
174
+ material="canopy_mat" contype="0" conaffinity="0"/>
175
+ </body>
176
+ <body name="tree_17" pos="7 -5 0">
177
+ <geom name="trunk_17" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
178
+ material="trunk_mat" contype="1" conaffinity="1"/>
179
+ <geom name="canopy_17" type="sphere" size="0.6" pos="0 0 3.2"
180
+ material="canopy_mat" contype="0" conaffinity="0"/>
181
+ </body>
182
+ <body name="tree_18" pos="-8 0 0">
183
+ <geom name="trunk_18" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
184
+ material="trunk_mat" contype="1" conaffinity="1"/>
185
+ <geom name="canopy_18" type="sphere" size="0.6" pos="0 0 3.2"
186
+ material="canopy_mat" contype="0" conaffinity="0"/>
187
+ </body>
188
+ <body name="tree_19" pos="0 -8 0">
189
+ <geom name="trunk_19" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
190
+ material="trunk_mat" contype="1" conaffinity="1"/>
191
+ <geom name="canopy_19" type="sphere" size="0.6" pos="0 0 3.2"
192
+ material="canopy_mat" contype="0" conaffinity="0"/>
193
+ </body>
194
+ <body name="tree_20" pos="5 8 0">
195
+ <geom name="trunk_20" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
196
+ material="trunk_mat" contype="1" conaffinity="1"/>
197
+ <geom name="canopy_20" type="sphere" size="0.6" pos="0 0 3.2"
198
+ material="canopy_mat" contype="0" conaffinity="0"/>
199
+ </body>
200
+ <body name="tree_21" pos="-6 6 0">
201
+ <geom name="trunk_21" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
202
+ material="trunk_mat" contype="1" conaffinity="1"/>
203
+ <geom name="canopy_21" type="sphere" size="0.6" pos="0 0 3.2"
204
+ material="canopy_mat" contype="0" conaffinity="0"/>
205
+ </body>
206
+ <body name="tree_22" pos="8 3 0">
207
+ <geom name="trunk_22" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
208
+ material="trunk_mat" contype="1" conaffinity="1"/>
209
+ <geom name="canopy_22" type="sphere" size="0.6" pos="0 0 3.2"
210
+ material="canopy_mat" contype="0" conaffinity="0"/>
211
+ </body>
212
+ <body name="tree_23" pos="-2 8 0">
213
+ <geom name="trunk_23" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
214
+ material="trunk_mat" contype="1" conaffinity="1"/>
215
+ <geom name="canopy_23" type="sphere" size="0.6" pos="0 0 3.2"
216
+ material="canopy_mat" contype="0" conaffinity="0"/>
217
+ </body>
218
+ <body name="tree_24" pos="1 -4 0">
219
+ <geom name="trunk_24" type="cylinder" size="0.15 1.5" pos="0 0 1.5"
220
+ material="trunk_mat" contype="1" conaffinity="1"/>
221
+ <geom name="canopy_24" type="sphere" size="0.6" pos="0 0 3.2"
222
+ material="canopy_mat" contype="0" conaffinity="0"/>
223
+ </body>
224
<!-- Overhead tracking camera.
     A <camera> is only valid inside <worldbody> or a <body>; as a direct
     child of <mujoco> it is a schema violation and the model fails to load,
     so it is declared here, before </worldbody>.
     NOTE(review): in the world body, mode="trackcom" follows the centre of
     mass of the whole model rather than the drone alone; confirm that this
     is the intended framing. -->
    <camera name="tracking" pos="0 0 25" xyaxes="1 0 0 0 1 0" mode="trackcom"
            fovy="60"/>
  </worldbody>

  <!-- Site-based thrust actuators: force applied in site-local z (tilts with drone).
       The sixth gear component adds a reaction torque about site-local z so that
       the flight controller's yaw differential (FL and BR up, FR and BL down)
       produces a net yaw torque; with force-only gear the four z-forces cancel
       about z and the yaw command has no effect.
       NOTE(review): 0.01 (N*m per N of thrust) is a placeholder torque/thrust
       ratio; tune to the intended rotor model. -->
  <actuator>
    <general name="thrust_fr" site="rotor_fr" gear="0 0 1 0 0 -0.01" ctrllimited="true"
             ctrlrange="0 3" gainprm="1" biasprm="0 0 0"/>
    <general name="thrust_fl" site="rotor_fl" gear="0 0 1 0 0 0.01" ctrllimited="true"
             ctrlrange="0 3" gainprm="1" biasprm="0 0 0"/>
    <general name="thrust_br" site="rotor_br" gear="0 0 1 0 0 0.01" ctrllimited="true"
             ctrlrange="0 3" gainprm="1" biasprm="0 0 0"/>
    <general name="thrust_bl" site="rotor_bl" gear="0 0 1 0 0 -0.01" ctrllimited="true"
             ctrlrange="0 3" gainprm="1" biasprm="0 0 0"/>
  </actuator>
</mujoco>
server/drone_forest_environment.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Drone Forest Navigation Environment.
3
+
4
+ A quadrotor drone navigates through a forest of columns (trees) to reach a target.
5
+ The RL policy commands velocity (forward/left/up/turn) while a built-in PD flight
6
+ controller handles low-level motor mixing.
7
+ """
8
+
9
+ import base64
10
+ import io
11
+ import os
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Any, Dict, List, Optional
15
+ from uuid import uuid4
16
+
17
+ # Configure MuJoCo rendering backend before importing mujoco
18
+ if "MUJOCO_GL" not in os.environ and sys.platform != "darwin":
19
+ os.environ.setdefault("MUJOCO_GL", "egl")
20
+
21
+ import numpy as np
22
+
23
+ try:
24
+ from openenv.core.env_server.interfaces import Environment
25
+
26
+ from ..models import DMControlAction, DMControlObservation, DMControlState
27
+ except ImportError:
28
+ from openenv.core.env_server.interfaces import Environment
29
+
30
+ try:
31
+ import sys as _sys
32
+ from pathlib import Path as _Path
33
+
34
+ _parent = str(_Path(__file__).parent.parent)
35
+ if _parent not in _sys.path:
36
+ _sys.path.insert(0, _parent)
37
+ from models import DMControlAction, DMControlObservation, DMControlState
38
+ except ImportError:
39
+ try:
40
+ from dm_control_env.models import (
41
+ DMControlAction,
42
+ DMControlObservation,
43
+ DMControlState,
44
+ )
45
+ except ImportError:
46
+ from envs.dm_control_env.models import (
47
+ DMControlAction,
48
+ DMControlObservation,
49
+ DMControlState,
50
+ )
51
+
52
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
NUM_TREES = 25  # must match the tree_<i> / trunk_<i> names in drone_forest.xml
ARENA_HALF = 10.0  # arena is 20x20 m
MAX_ALTITUDE = 8.0
MIN_ALTITUDE = 0.1
TARGET_RADIUS = 0.5  # success if within this distance
TREE_MIN_SPACING = 1.5  # min distance between tree centres
SPAWN_CLEAR_RADIUS = 2.0  # keep trees away from spawn
TARGET_MIN_DIST = 5.0  # target at least this far from spawn
MAX_STEPS = 1000
PHYSICS_DT = 0.002
CONTROL_DT = 0.02  # 50 Hz control

# Velocity limits
MAX_XY_VEL = 3.0  # m/s
MAX_Z_VEL = 2.0  # m/s
MAX_YAW_RATE = 2.0  # rad/s

# Flight-controller PD gains
KP_VEL = 4.0
KD_VEL = 1.5
KP_ATT = 8.0
KD_ATT = 2.0

# Drone physical parameters
# The drone body in drone_forest.xml declares an explicit
# <inertial mass="0.4" .../>.  When a body has an <inertial> element MuJoCo
# uses it and does not add the per-geom masses, so the simulated mass is
# 0.4 kg.  Using body + arms (0.48 kg) here over-estimates the hover
# feed-forward by 20% and makes the drone climb on a zero command.
DRONE_MASS = 0.4
GRAVITY = 9.81
HOVER_THRUST = DRONE_MASS * GRAVITY / 4.0  # per-motor hover
ARM_LENGTH = 0.14  # per-axis (x and y) offset of each rotor site from the body origin

XML_PATH = str(Path(__file__).parent / "drone_forest.xml")
85
+
86
+
87
+ class DroneForestEnvironment(Environment):
88
+ """Drone navigates a randomised forest of columns to reach a target."""
89
+
90
+ SUPPORTS_CONCURRENT_SESSIONS = True
91
+
92
def __init__(
    self,
    render_height: Optional[int] = None,
    render_width: Optional[int] = None,
    **kwargs,
):
    """Create an environment with no simulator loaded yet.

    Args:
        render_height: Image height in pixels. When falsy, the
            ``DMCONTROL_RENDER_HEIGHT`` environment variable is used
            (default ``"480"``).
        render_width: Image width in pixels. When falsy, the
            ``DMCONTROL_RENDER_WIDTH`` environment variable is used
            (default ``"640"``).
        **kwargs: Accepted and ignored.
    """
    # Simulator handles are filled in by _ensure_model().
    self._model = None
    self._data = None

    # Explicit sizes win; the environment variable is only parsed as a fallback.
    if render_height:
        self._render_height = render_height
    else:
        self._render_height = int(os.environ.get("DMCONTROL_RENDER_HEIGHT", "480"))
    if render_width:
        self._render_width = render_width
    else:
        self._render_width = int(os.environ.get("DMCONTROL_RENDER_WIDTH", "640"))

    # Per-episode bookkeeping.
    self._rng = np.random.RandomState()
    self._tree_positions: List[np.ndarray] = []
    self._target_pos = np.zeros(3)
    self._prev_dist = None
    self._step_count = 0
    self._done = False
    self._include_pixels = False

    self._state = DMControlState(
        episode_id=str(uuid4()),
        step_count=0,
        domain_name="drone_forest",
        task_name="navigate",
    )
120
+
121
+ # ------------------------------------------------------------------
122
+ # Model loading
123
+ # ------------------------------------------------------------------
124
def _ensure_model(self):
    """Load the MuJoCo model once and cache the ids and specs derived from it.

    Does nothing when a model is already present. Otherwise it builds the
    model and data from ``XML_PATH``, resolves the body and geom ids used
    by the other methods, and fills the action/observation specs and
    timesteps on ``self._state``.
    """
    if self._model is not None:
        return
    import mujoco

    model = mujoco.MjModel.from_xml_path(XML_PATH)
    self._model = model
    self._data = mujoco.MjData(model)

    def body_id(name):
        # Id of the named body in the freshly loaded model.
        return mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, name)

    def geom_id(name):
        # Id of the named geom in the freshly loaded model.
        return mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_GEOM, name)

    def float_spec(size):
        # A fresh dict per entry so that specs never alias each other.
        return {"shape": [size], "dtype": "float64"}

    self._drone_body_id = body_id("drone")
    self._target_body_id = body_id("target")
    self._tree_body_ids = [body_id(f"tree_{i}") for i in range(NUM_TREES)]
    self._trunk_geom_ids = [geom_id(f"trunk_{i}") for i in range(NUM_TREES)]
    self._drone_body_geom_id = geom_id("drone_body")
    self._ground_geom_id = geom_id("ground")

    # Metadata describing the 4-element normalised command and the observation dict.
    self._state.action_spec = {
        "shape": [4],
        "dtype": "float64",
        "minimum": [-1.0, -1.0, -1.0, -1.0],
        "maximum": [1.0, 1.0, 1.0, 1.0],
        "name": "velocity_command",
    }
    self._state.observation_spec = {
        "position": float_spec(3),
        "velocity": float_spec(3),
        "orientation": float_spec(3),
        "angular_velocity": float_spec(3),
        "target_relative": float_spec(3),
        "obstacle_distances": float_spec(8),
    }
    self._state.physics_timestep = PHYSICS_DT
    self._state.control_timestep = CONTROL_DT
171
+
172
+ # ------------------------------------------------------------------
173
+ # Forest randomisation
174
+ # ------------------------------------------------------------------
175
def _randomise_forest(self):
    """Sample a new tree layout and target, then write both into the model.

    Trees are drawn by rejection sampling inside the arena (with a 1 m
    border), outside ``SPAWN_CLEAR_RADIUS`` of the origin and at least
    ``TREE_MIN_SPACING`` apart, for at most 5000 attempts. Any trees still
    missing are placed at (100, 100). The target is sampled up to 1000
    times at a horizontal distance between ``TARGET_MIN_DIST`` and
    ``ARENA_HALF - 2`` from the origin; the first sample at least 1.5 m
    from every tree is kept, otherwise the last sample is used.
    """
    import mujoco

    rng = self._rng
    lower, upper = -ARENA_HALF + 1, ARENA_HALF - 1

    trees = []
    tries = 0
    while len(trees) < NUM_TREES and tries < 5000:
        tries += 1
        # Draw x first, then y, so a seeded generator yields the same layout.
        cx = rng.uniform(lower, upper)
        cy = rng.uniform(lower, upper)
        if np.sqrt(cx ** 2 + cy ** 2) < SPAWN_CLEAR_RADIUS:
            continue
        crowded = any(
            np.sqrt((cx - t[0]) ** 2 + (cy - t[1]) ** 2) < TREE_MIN_SPACING
            for t in trees
        )
        if not crowded:
            trees.append(np.array([cx, cy]))

    # Trees that could not be placed are put far outside the arena.
    trees.extend(np.array([100.0, 100.0]) for _ in range(NUM_TREES - len(trees)))
    self._tree_positions = trees

    # Move the tree bodies in the model to the sampled positions.
    for body_id, tree in zip(self._tree_body_ids, trees):
        self._model.body_pos[body_id] = [tree[0], tree[1], 0.0]

    # Target: at least TARGET_MIN_DIST from the origin and clear of trees.
    for _ in range(1000):
        bearing = rng.uniform(0, 2 * np.pi)
        radius = rng.uniform(TARGET_MIN_DIST, ARENA_HALF - 2)
        tx, ty = radius * np.cos(bearing), radius * np.sin(bearing)
        tz = rng.uniform(1.0, 3.0)
        blocked = any(
            np.sqrt((tx - t[0]) ** 2 + (ty - t[1]) ** 2) < 1.5
            for t in trees[:NUM_TREES]
        )
        if not blocked:
            break

    self._target_pos = np.array([tx, ty, tz])
    self._model.body_pos[self._target_body_id] = self._target_pos.copy()

    # Recompute derived quantities now that body positions have changed.
    mujoco.mj_forward(self._model, self._data)
229
+
230
+ # ------------------------------------------------------------------
231
+ # Flight controller
232
+ # ------------------------------------------------------------------
233
def _flight_controller(self, cmd: np.ndarray) -> np.ndarray:
    """
    Convert velocity commands [vx, vy, vz, yaw_rate] in [-1, 1]
    to 4 motor thrusts.

    Args:
        cmd: Normalised command. cmd[0] and cmd[1] are forward/left
            velocity, cmd[2] is vertical velocity and cmd[3] is yaw rate;
            each is scaled by the matching ``MAX_*`` limit.

    Returns:
        Array ``[t_fr, t_fl, t_br, t_bl]`` clipped to the actuator
        range [0, 3].
    """
    # Scale commands
    vx_cmd = cmd[0] * MAX_XY_VEL
    vy_cmd = cmd[1] * MAX_XY_VEL
    vz_cmd = cmd[2] * MAX_Z_VEL
    yaw_rate_cmd = cmd[3] * MAX_YAW_RATE

    # Current state.  For a MuJoCo free joint the linear velocity is
    # expressed in the world frame and the angular velocity in the body frame.
    quat = self._data.qpos[3:7].copy()  # w, x, y, z
    vel = self._data.qvel[:3].copy()
    ang_vel = self._data.qvel[3:6].copy()

    roll, pitch, yaw = self._quat_to_euler(quat)

    # Roll and pitch tilt the thrust vector along the drone's own axes, so
    # the horizontal velocity error must be expressed in the heading
    # (yaw-aligned) frame.  Rotate the measured world velocity into that
    # frame and compare it with the forward/left command directly.
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    v_forward = cos_yaw * vel[0] + sin_yaw * vel[1]
    v_left = -sin_yaw * vel[0] + cos_yaw * vel[1]

    vx_err = vx_cmd - v_forward
    vy_err = vy_cmd - v_left
    vz_err = vz_cmd - vel[2]

    # Desired roll/pitch from XY velocity error (small angle approx).
    # Positive pitch (about +y) tips the thrust vector towards +x; positive
    # roll (about +x) tips it towards -y, hence the sign flip on roll.
    desired_pitch = np.clip(KP_VEL * vx_err, -0.5, 0.5)
    desired_roll = np.clip(-KP_VEL * vy_err, -0.5, 0.5)

    # Attitude PD
    roll_err = desired_roll - roll
    pitch_err = desired_pitch - pitch
    yaw_rate_err = yaw_rate_cmd - ang_vel[2]

    torque_roll = KP_ATT * roll_err - KD_ATT * ang_vel[0]
    torque_pitch = KP_ATT * pitch_err - KD_ATT * ang_vel[1]
    torque_yaw = KP_ATT * yaw_rate_err

    # Collective thrust: hover + vertical velocity correction
    thrust = DRONE_MASS * GRAVITY + KP_VEL * vz_err * DRONE_MASS

    # Quadrotor mixer: convert thrust + torques to 4 motor thrusts.
    # Layout: FR(+x,-y), FL(+x,+y), BR(-x,-y), BL(-x,+y)
    # A +z force F at rotor offset (x, y) gives torque (y*F, -x*F, 0):
    #   roll  torque (+x axis) grows with thrust on the +y rotors (FL, BL);
    #   pitch torque (+y axis) grows with thrust on the -x rotors (BR, BL).
    # The pitch term is therefore subtracted at the front and added at the
    # rear; the opposite sign drives the attitude loop with positive feedback.
    L = ARM_LENGTH
    base = thrust / 4.0
    d_pitch = torque_pitch / (4.0 * L)
    d_roll = torque_roll / (4.0 * L)
    d_yaw = torque_yaw / 4.0
    t_fr = base - d_pitch - d_roll - d_yaw
    t_fl = base - d_pitch + d_roll + d_yaw
    t_br = base + d_pitch - d_roll + d_yaw
    t_bl = base + d_pitch + d_roll - d_yaw

    # Clamp to actuator range [0, 3]
    motors = np.clip([t_fr, t_fl, t_br, t_bl], 0.0, 3.0)
    return motors
291
+
292
@staticmethod
def _quat_to_euler(quat: np.ndarray):
    """Return ``(roll, pitch, yaw)`` in radians for a ``[w, x, y, z]`` quaternion.

    Roll is the rotation about x, pitch about y and yaw about z. The
    pitch argument is clipped to [-1, 1] before ``arcsin``.
    """
    qw, qx, qy, qz = quat

    roll = np.arctan2(
        2.0 * (qw * qx + qy * qz),
        1.0 - 2.0 * (qx * qx + qy * qy),
    )
    # Clip so that rounding error cannot push arcsin out of its domain.
    pitch = np.arcsin(np.clip(2.0 * (qw * qy - qz * qx), -1.0, 1.0))
    yaw = np.arctan2(
        2.0 * (qw * qz + qx * qy),
        1.0 - 2.0 * (qy * qy + qz * qz),
    )
    return roll, pitch, yaw
309
+
310
+ # ------------------------------------------------------------------
311
+ # Observations
312
+ # ------------------------------------------------------------------
313
def _get_obs(self) -> Dict[str, List[float]]:
    """Build the observation dictionary from the current simulator state.

    Returns:
        A dict with ``position``, ``velocity``, ``orientation``
        (roll, pitch, yaw), ``angular_velocity``, ``target_relative``
        (target minus drone position) and ``obstacle_distances``: the 8
        smallest XY distances from the drone to the stored tree positions,
        ascending, padded with 50.0 when fewer than 8 trees are stored.
    """
    qpos, qvel = self._data.qpos, self._data.qvel
    position = qpos[:3].copy()
    linear_velocity = qvel[:3].copy()
    angular_velocity = qvel[3:6].copy()
    roll, pitch, yaw = self._quat_to_euler(qpos[3:7].copy())

    # Horizontal distance to every tree, nearest first.
    ranges = sorted(
        np.sqrt((tree[0] - position[0]) ** 2 + (tree[1] - position[1]) ** 2)
        for tree in self._tree_positions
    )
    nearest = ranges[:8]
    nearest.extend([50.0] * (8 - len(nearest)))

    return {
        "position": position.tolist(),
        "velocity": linear_velocity.tolist(),
        "orientation": [float(roll), float(pitch), float(yaw)],
        "angular_velocity": angular_velocity.tolist(),
        "target_relative": (self._target_pos - position).tolist(),
        "obstacle_distances": nearest,
    }
342
+
343
+ # ------------------------------------------------------------------
344
+ # Collision detection
345
+ # ------------------------------------------------------------------
346
def _check_collisions(self) -> bool:
    """Return True if the drone body geom touches the ground or a tree trunk.

    Only contacts involving the ``drone_body`` geom are considered, and
    only when the other geom is the ground plane or one of the cached
    trunk geoms.
    """
    drone_geom = self._drone_body_geom_id
    ground_geom = self._ground_geom_id
    trunk_geoms = self._trunk_geom_ids

    for i in range(self._data.ncon):
        contact = self._data.contact[i]
        # Pick the geom on the other side of the contact; skip contacts
        # that do not involve the drone body at all.
        if contact.geom1 == drone_geom:
            other = contact.geom2
        elif contact.geom2 == drone_geom:
            other = contact.geom1
        else:
            continue
        if other == ground_geom or other in trunk_geoms:
            return True
    return False
359
+
360
+ # ------------------------------------------------------------------
361
+ # Reward
362
+ # ------------------------------------------------------------------
363
def _compute_reward(self, pos: np.ndarray) -> float:
    """Return the shaped step reward for the drone at ``pos``.

    The reward is the decrease in distance to the target since the
    previous call (omitted when no previous distance is stored) minus a
    fixed 0.01 time penalty. The stored previous distance is updated.
    """
    current = np.linalg.norm(self._target_pos - pos)
    previous = self._prev_dist
    self._prev_dist = current

    total = 0.0
    if previous is not None:
        # Progress term: positive when the drone moved closer.
        total += 1.0 * (previous - current)
    # Time pressure.
    total -= 0.01
    return float(total)
376
+
377
+ # ------------------------------------------------------------------
378
+ # Termination
379
+ # ------------------------------------------------------------------
380
def _check_termination(self, pos: np.ndarray):
    """Decide whether the episode ends at ``pos``.

    Checks are made in this order and the first match wins: target
    reached (+100), collision (-50), outside the arena or altitude band
    (-10), step limit reached (0).

    Returns:
        A ``(done, bonus_reward)`` tuple; ``(False, 0.0)`` when no
        condition applies.
    """
    # Success: within TARGET_RADIUS of the target.
    if np.linalg.norm(self._target_pos - pos) < TARGET_RADIUS:
        return True, 100.0

    # Collision with a trunk or the ground.
    if self._check_collisions():
        return True, -50.0

    # Out of bounds, horizontally or vertically.
    outside_xy = abs(pos[0]) > ARENA_HALF or abs(pos[1]) > ARENA_HALF
    if outside_xy or pos[2] > MAX_ALTITUDE or pos[2] < MIN_ALTITUDE:
        return True, -10.0

    # Step budget exhausted.
    if self._step_count >= MAX_STEPS:
        return True, 0.0

    return False, 0.0
402
+
403
+ # ------------------------------------------------------------------
404
+ # Core interface
405
+ # ------------------------------------------------------------------
406
def reset(
    self,
    domain_name: Optional[str] = None,
    task_name: Optional[str] = None,
    seed: Optional[int] = None,
    render: bool = False,
    **kwargs,
) -> DMControlObservation:
    """Start a new episode with a new forest layout.

    Args:
        domain_name: Accepted but not used by this environment.
        task_name: Accepted but not used by this environment.
        seed: When given, re-seeds the layout random generator.
        render: When true, the returned observation includes pixels.
        **kwargs: Accepted and ignored.

    Returns:
        The initial observation, with reward 0.0 and ``done`` False.
    """
    import mujoco

    self._ensure_model()
    self._include_pixels = render
    if seed is not None:
        self._rng = np.random.RandomState(seed)

    model, data = self._model, self._data

    # Clear the simulator state before sampling the new layout.
    mujoco.mj_resetData(model, data)
    self._randomise_forest()

    # Drone starts at rest at the origin, 1.5 m up, with identity orientation.
    data.qpos[:3] = [0.0, 0.0, 1.5]
    data.qpos[3:7] = [1.0, 0.0, 0.0, 0.0]
    data.qvel[:] = 0.0
    mujoco.mj_forward(model, data)

    # Episode bookkeeping.
    self._step_count = 0
    start = data.qpos[:3].copy()
    self._prev_dist = float(np.linalg.norm(self._target_pos - start))
    self._done = False

    # New episode id; specs are carried over from the previous state object.
    previous_state = self._state
    self._state = DMControlState(
        episode_id=str(uuid4()),
        step_count=0,
        domain_name="drone_forest",
        task_name="navigate",
        action_spec=previous_state.action_spec,
        observation_spec=previous_state.observation_spec,
        physics_timestep=PHYSICS_DT,
        control_timestep=CONTROL_DT,
    )

    observations = self._get_obs()
    frame = self._render_pixels() if render else None
    return DMControlObservation(
        observations=observations,
        pixels=frame,
        reward=0.0,
        done=False,
    )
460
+
461
def step(
    self,
    action: DMControlAction,
    render: bool = False,
    **kwargs,
) -> DMControlObservation:
    """Apply one control action and advance the simulation by one control step.

    Args:
        action: Command whose first four ``values`` are clipped to
            ``[-1, 1]`` and handed to ``self._flight_controller``.
        render: When true (or when ``reset`` was called with
            ``render=True``), the observation carries a rendered frame.
        **kwargs: Ignored.

    Returns:
        Observation holding the new state, the step reward (shaping reward
        plus any termination bonus) and the ``done`` flag.

    Raises:
        RuntimeError: If the model/data are missing (``reset`` was never
            called or ``close`` was called), or the episode already ended.
    """
    import mujoco

    if self._model is None or self._data is None:
        raise RuntimeError("Environment not initialized. Call reset() first.")

    if self._done:
        raise RuntimeError("Episode is done. Call reset() to start a new episode.")

    # Clip action to [-1, 1]
    cmd = np.clip(np.array(action.values[:4], dtype=np.float64), -1.0, 1.0)

    # Run flight controller to get motor thrusts and apply them.
    motors = self._flight_controller(cmd)
    self._data.ctrl[:4] = motors

    # Number of physics substeps per control step. Round to the nearest
    # integer instead of truncating: a float ratio such as 0.06 / 0.02
    # evaluates to 2.9999999999999996, and int() would silently drop a
    # substep. At least one substep always runs, so a step can never count
    # toward the episode without advancing the physics.
    n_substeps = max(1, int(round(CONTROL_DT / PHYSICS_DT)))
    for _ in range(n_substeps):
        mujoco.mj_step(self._model, self._data)

    self._step_count += 1
    self._state.step_count = self._step_count

    pos = self._data.qpos[:3].copy()

    # Shaping reward first (it updates the stored distance), then the
    # termination check, whose bonus/penalty is added on top.
    reward = self._compute_reward(pos)
    done, bonus = self._check_termination(pos)
    reward += bonus
    self._done = done

    obs = self._get_obs()
    pixels = self._render_pixels() if (render or self._include_pixels) else None

    return DMControlObservation(
        observations=obs,
        pixels=pixels,
        reward=float(reward),
        done=done,
    )
509
+
510
async def reset_async(self, **kwargs) -> DMControlObservation:
    """Async wrapper around :meth:`reset`.

    On macOS (``sys.platform == "darwin"``) the reset runs inline on the
    calling thread; on every other platform it is offloaded to a worker
    thread via ``asyncio.to_thread``. All keyword arguments are forwarded
    unchanged.
    """
    # NOTE(review): the macOS special case presumably keeps MuJoCo/GL work on
    # the calling thread; confirm the exact reason with the author.
    if sys.platform != "darwin":
        import asyncio

        return await asyncio.to_thread(self.reset, **kwargs)
    return self.reset(**kwargs)
516
+
517
async def step_async(
    self, action: DMControlAction, render: bool = False, **kwargs
) -> DMControlObservation:
    """Async wrapper around :meth:`step`.

    On macOS (``sys.platform == "darwin"``) the step runs inline on the
    calling thread; on every other platform it is offloaded to a worker
    thread via ``asyncio.to_thread``. ``action``, ``render`` and any extra
    keyword arguments are forwarded unchanged.
    """
    if sys.platform != "darwin":
        import asyncio

        return await asyncio.to_thread(
            self.step, action, render=render, **kwargs
        )
    return self.step(action, render=render, **kwargs)
523
+
524
+ # ------------------------------------------------------------------
525
+ # Rendering
526
+ # ------------------------------------------------------------------
527
def _render_pixels(self) -> Optional[str]:
    """Render the ``"tracking"`` camera and return the frame as base64 PNG.

    A fresh ``mujoco.Renderer`` of size ``self._render_height`` x
    ``self._render_width`` is created for the call and always closed again,
    even when scene update or rendering fails.

    Returns:
        The PNG image encoded as a base64 UTF-8 string, or ``None`` when
        rendering is unavailable or fails (best-effort: no exception
        escapes this method).
    """
    import logging

    try:
        import mujoco
        from PIL import Image

        renderer = mujoco.Renderer(
            self._model, height=self._render_height, width=self._render_width
        )
        try:
            renderer.update_scene(self._data, camera="tracking")
            frame = renderer.render()
        finally:
            # Release the renderer's resources on every path; previously a
            # failure in update_scene()/render() skipped close() and leaked it.
            renderer.close()

        buf = io.BytesIO()
        Image.fromarray(frame).save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception:
        # Rendering is optional, so keep the best-effort contract, but leave
        # a trace (debug level to avoid per-step log spam on headless hosts).
        logging.getLogger(__name__).debug(
            "Frame rendering failed; returning no pixels", exc_info=True
        )
        return None
541
+
542
@property
def state(self) -> DMControlState:
    """Current episode state object (the instance stored in ``self._state``)."""
    current_state = self._state
    return current_state
545
+
546
def close(self) -> None:
    """Drop the references to the MuJoCo model and data.

    Afterwards ``step`` raises ``RuntimeError`` until ``reset`` is called.
    """
    self._model = self._data = None
549
+
550
def __del__(self):
    """Best-effort cleanup when the environment object is finalized."""
    try:
        self.close()
    except Exception:
        # A finalizer must never raise: the object may be only partially
        # constructed or the interpreter may be shutting down.
        return