# Controller/training/data_loader.py
# (Gen-HVAC upload, revision 1641a08)
import os
import glob
import re
import hashlib
from typing import Dict, List, Optional, Any, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import json
# CONFIG & REGISTRY
# Raw observation keys to drop before tokenization (none by default).
DROP_OBS_KEYS = []
DATA_DIR = "TrajectoryData_from_docker"
INDEX_CACHE_PATH = os.path.join(DATA_DIR, "episode_index_cache_topk.json")  # EpisodeIndex cache
NORM_CACHE_PATH = os.path.join(DATA_DIR, "norm_stats_v_topk.npz")           # normalization-stats cache
# --- Token vocabulary layout ---
PAD_ID = 0
UNK_ID = 1
SENSOR_START_ID = 2    # sensor token ids are offset from here
ACTION_START_ID = 300  # action token ids are offset from here
VOCAB_SIZE = 512       # ids that would reach/exceed this fall back to UNK_ID
CONTEXT_LEN = 48       # timesteps per training window
MAX_TOKENS_PER_STEP = 64  # state + action tokens packed per timestep
MAX_ZONES = 32            # zone ids are clamped to [0, MAX_ZONES - 1]
PHYSICS_HORIZON = 16      # steps ahead for the auxiliary future-state delta target
SEED = 42
# --- Top-K episode selection ---
USE_TOPK = True
TOPK_FRAC = 0.8        # fraction of episodes retained
TOPK_MODE = "filter"   # "filter" | "weighted"
TOPK_ON = "energy"     # ranking signal: "energy" | "comfort" | "mixed" | "pareto"
TOPK_BOOST = 3.0       # NOTE(review): unused in this file — presumably meant for "weighted" mode; confirm
# --- Action Discretization ---
NUM_ACTION_BINS = 64
HTG_LOW, HTG_HIGH = 15.0, 30.0  # heating-setpoint clipping range
CLG_LOW, CLG_HIGH = 15.0, 30.0  # cooling-setpoint clipping range
# --- Normalization & Scaling ---
USE_NORMALIZATION = True
ACTION_VALUE_INPUT_MODE = "prev"  # action-token inputs carry the previous step's actions
ACTION_VALUE_MASK_CONST = 0.0     # NOTE(review): unused in this file
COMFORT_SCALE = 1.0               # NOTE(review): unused in this file
# --- Preference conditioning ---
PREF_MODE = "sample"      # "sample" draws lambda ~ Beta(a, b); "fixed" uses PREF_FIXED_LAMBDA
PREF_FIXED_LAMBDA = 0.5
PREF_BETA_A = 5.0
PREF_BETA_B = 2.0
# Provenance tags recording HOW a zone id was extracted from a feature name.
ZONE_SRC_REGEX = 1
ZONE_SRC_PAREN = 2
ZONE_SRC_CORE_PERIM = 3
ZONE_SRC_HASH = 4
# Keyword -> base token id lookup; first substring match wins (dict order).
HVAC_KEYWORD_MAP = {
    # Sensors (2..299)
    "temp": 10, "t_in": 10, "temperature": 10,
    "humidity": 11, "rh": 11,
    "co2": 12, "ppm": 12,
    "power": 13, "energy": 13, "kw": 13,
    "occupancy": 14, "occ": 14, "people": 14,
    "solar": 15, "rad": 15, "radiation": 15,
    "outdoor": 16, "site": 16, "environment": 16,
    "pressure": 17, "flow": 18, "fan": 19, "speed": 19,
    # Actions (offset from ACTION_START_ID)
    "setpoint": 10, "stpt": 10,
    "damper": 11, "position": 11, "valve": 12,
}
# ============================================================
# HELPER
# ============================================================
def compute_comfort_indices_from_state_keys(state_keys: List[str]) -> List[int]:
    """Locate ASH55 discomfort columns in the state-key list.

    Prefers aggregate columns (name contains 'any'); if none exist, falls
    back to every per-zone 'notcomfortable' ASH55 column.
    """
    lowered = [str(key).lower() for key in state_keys]

    def is_ash55(key: str) -> bool:
        return "ash55" in key and "notcomfortable" in key

    aggregate = [i for i, key in enumerate(lowered) if is_ash55(key) and "any" in key]
    if aggregate:
        return aggregate
    return [i for i, key in enumerate(lowered) if is_ash55(key)]
def extract_zone_id_with_source(name_lower: str) -> Tuple[int, int]:
    """Extract a zone id from a lowercased feature name.

    Tries, in order: an explicit 'zone/z/zn <N>' token, the same pattern
    inside parentheses, then 'core/perim(eter) <N>'. If all fail, derives a
    stable pseudo-id from an md5 hash. Returns (zone_id, source_tag).
    """
    zone_pattern = r'(?:\bzone\b|\bz\b|\bzn\b)[_\s\-]*?(\d+)\b'

    def clamp(zid: int) -> int:
        # Zone ids are kept within the embedding-table range.
        return min(max(zid, 0), MAX_ZONES - 1)

    direct = re.search(zone_pattern, name_lower)
    if direct is not None:
        return clamp(int(direct.group(1))), ZONE_SRC_REGEX

    for chunk in re.findall(r'\(([^)]+)\)', name_lower):
        inner = re.search(zone_pattern, chunk)
        if inner is not None:
            return clamp(int(inner.group(1))), ZONE_SRC_PAREN

    core_perim = re.search(r'(?:perimeter|perim|core)[_\s\-]*?(?:zn[_\s\-]*)?(\d+)\b', name_lower)
    if core_perim is not None:
        return clamp(int(core_perim.group(1))), ZONE_SRC_CORE_PERIM

    # Fallback: hash into [1, MAX_ZONES - 1] so unknown names still get a bucket.
    digest = int(hashlib.md5(name_lower.encode()).hexdigest(), 16)
    return 1 + (digest % max(1, (MAX_ZONES - 1))), ZONE_SRC_HASH
def parse_feature_identity(name: str, is_action: bool = False) -> Tuple[int, int, int]:
    """Map a raw feature name to (vocab token id, zone id, zone-id source).

    The base id comes from the first HVAC_KEYWORD_MAP keyword found as a
    substring; otherwise a stable hash bucket in [50, 99]. The base is then
    offset into the sensor or action region of the vocabulary; out-of-range
    ids collapse to UNK_ID.
    """
    lowered = str(name).lower()
    zone_id, zone_src = extract_zone_id_with_source(lowered)

    base_id = next((tok for kw, tok in HVAC_KEYWORD_MAP.items() if kw in lowered), UNK_ID)
    if base_id == UNK_ID:
        digest = int(hashlib.md5(lowered.encode()).hexdigest(), 16)
        base_id = 50 + (digest % 50)

    offset = ACTION_START_ID if is_action else SENSOR_START_ID
    final_id = offset + base_id
    if final_id >= VOCAB_SIZE:
        final_id = UNK_ID
    return final_id, zone_id, zone_src
def discretize_actions_to_bins(actions: np.ndarray, action_keys: List[str]) -> np.ndarray:
    """Quantize continuous setpoint actions (T, A) into integer bins.

    Each column is clipped to its heating or cooling range (cooling when the
    key contains 'clg'/'cool'), scaled to [0, 1], and rounded onto
    NUM_ACTION_BINS evenly spaced bins.
    """
    binned = np.zeros_like(actions, dtype=np.int64)
    for col, key in enumerate(action_keys):
        lowered = key.lower()
        is_cooling = ("clg" in lowered) or ("cool" in lowered)
        lo, hi = (CLG_LOW, CLG_HIGH) if is_cooling else (HTG_LOW, HTG_HIGH)
        clipped = np.clip(actions[:, col], lo, hi)
        frac = (clipped - lo) / (hi - lo + 1e-12)  # epsilon guards lo == hi
        indices = np.rint(frac * (NUM_ACTION_BINS - 1)).astype(np.int64)
        binned[:, col] = np.clip(indices, 0, NUM_ACTION_BINS - 1)
    return binned
def discounted_cumsum(x: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Return-to-go: y[t] = sum_{k >= t} gamma**(k - t) * x[k], as float32."""
    out = np.zeros_like(x, dtype=np.float32)
    acc = 0.0
    # Walk backwards so each step reuses the suffix sum already computed.
    for t in reversed(range(len(x))):
        acc = x[t] + gamma * acc
        out[t] = acc
    return out
def _mix_u64(x: int) -> int:
x &= 0xFFFFFFFFFFFFFFFF
x ^= (x >> 33)
x = (x * 0xff51afd7ed558ccd) & 0xFFFFFFFFFFFFFFFF
x ^= (x >> 33)
x = (x * 0xc4ceb9fe1a85ec53) & 0xFFFFFFFFFFFFFFFF
x ^= (x >> 33)
return x & 0xFFFFFFFFFFFFFFFF
def dataset_signature(npz_paths: List[str]) -> str:
    """Fingerprint a file list by path, size and mtime (md5 of the manifest).

    Missing files contribute a '<path>|missing' entry, so the signature still
    changes when a file appears or disappears.
    """
    entries = []
    for path in npz_paths:
        try:
            info = os.stat(path)
            entries.append(f"{path}|{info.st_size}|{int(info.st_mtime)}")
        except FileNotFoundError:
            entries.append(f"{path}|missing")
    manifest = "\n".join(entries)
    return hashlib.md5(manifest.encode("utf-8")).hexdigest()
def compute_occupancy_indices_from_state_keys(state_keys: List[str]) -> List[int]:
    """Indices of occupant-count columns ('occ' and 'count' both in the name)."""
    found = []
    for i, key in enumerate(state_keys):
        lowered = str(key).lower()
        if "occ" in lowered and "count" in lowered:
            found.append(i)
    return found
# ============================================================
# 1) EPISODE INDEX
# ============================================================
class EpisodeIndex:
    """Scan-once index over trajectory .npz episode files.

    Records per-episode length, total returns, key names and token metadata.
    Everything is cached to INDEX_CACHE_PATH, keyed by a dataset signature
    (paths + sizes + mtimes), so unchanged datasets skip the expensive scan.
    """

    def __init__(self, npz_paths: List[str]):
        self.paths = list(npz_paths)
        self.T: List[int] = []                     # episode lengths (timesteps)
        self.returns_energy: List[float] = []      # per-episode summed energy reward
        self.returns_comfort: List[float] = []     # per-episode summed comfort reward
        self.s_meta: List[List[Tuple[int,int,int]]] = []   # per state key: (token id, zone id, zone src)
        self.a_meta: List[List[Tuple[int,int,int]]] = []   # per action key: (token id, zone id, zone src)
        self.state_keys: List[List[str]] = []
        self.action_keys: List[List[str]] = []
        self.keep_indices_map: List[List[int]] = []        # raw-obs columns kept after DROP_OBS_KEYS
        self.comfort_idx: List[List[int]] = []             # ASH55 columns, relative to the RAW obs array
        sig = dataset_signature(self.paths)
        # ---- Fast path: reuse the cached index when the dataset is unchanged ----
        if os.path.exists(INDEX_CACHE_PATH):
            try:
                with open(INDEX_CACHE_PATH, "r") as f:
                    cache = json.load(f)
                if cache.get("signature") == sig and "returns_energy" in cache:
                    print(f"[DataLoader] Loading cached index: {INDEX_CACHE_PATH}")
                    self.T = cache["T"]
                    self.returns_energy = cache["returns_energy"]
                    self.returns_comfort = cache["returns_comfort"]
                    self.state_keys = cache["state_keys"]
                    self.action_keys = cache["action_keys"]
                    self.keep_indices_map = cache.get("keep_indices_map", [])
                    # Token metadata is cheap to recompute from the cached key names.
                    self.s_meta = [[parse_feature_identity(k, is_action=False) for k in ks] for ks in self.state_keys]
                    self.a_meta = [[parse_feature_identity(k, is_action=True) for k in ks] for ks in self.action_keys]
                    if "comfort_idx" in cache:
                        self.comfort_idx = cache["comfort_idx"]
                    else:
                        # Cache predates comfort_idx: discard it and fall through to a re-scan.
                        print("[DataLoader] Cache missing comfort_idx. Rebuilding.")
                        raise ValueError("Outdated Cache")
                    print(f"[DataLoader] Cache loaded. Episodes indexed: {len(self.T)}")
                    return
                else:
                    print("[DataLoader] Cache signature mismatch")
            except Exception as e:
                # Any cache problem (corrupt JSON, missing keys) -> full re-scan.
                print(f"[DataLoader] Failed load cache: {e}")
        # ---- Slow path: scan every episode file ----
        for p in tqdm(self.paths, desc="Indexing"):
            try:
                with np.load(p, allow_pickle=True) as d:
                    obs = d["observations"]
                    if "rewards_energy" in d:
                        r_e = d["rewards_energy"]
                        r_c = d["rewards_comfort"]
                    else:
                        # Legacy single-reward files: treat the stream as energy, zero comfort.
                        r_e = d["rewards"]
                        r_c = np.zeros_like(r_e)
                    ret_e = float(np.sum(r_e))
                    ret_c = float(np.sum(r_c))
                    T = int(obs.shape[0])
                    # Get RAW keys
                    raw_s_keys = d["state_keys"].astype(object).tolist() if "state_keys" in d else []
                    a_keys = d["action_keys"].astype(object).tolist() if "action_keys" in d else []
                    raw_s_keys = list(map(str, raw_s_keys))
                    a_keys = list(map(str, a_keys))
                    # Comfort indices are computed BEFORE dropping columns (raw layout).
                    c_idx = compute_comfort_indices_from_state_keys(raw_s_keys)
                    keep_idxs = [i for i, k in enumerate(raw_s_keys) if k not in DROP_OBS_KEYS]
                    s_keys = [raw_s_keys[i] for i in keep_idxs]
                    s_meta = [parse_feature_identity(k, is_action=False) for k in s_keys]
                    a_meta = [parse_feature_identity(k, is_action=True) for k in a_keys]
                    self.T.append(T)
                    self.returns_energy.append(ret_e)
                    self.returns_comfort.append(ret_c)
                    self.state_keys.append(s_keys)
                    self.action_keys.append(a_keys)
                    self.comfort_idx.append(c_idx)  # Save indices relative to RAW array
                    self.s_meta.append(s_meta)
                    self.a_meta.append(a_meta)
                    self.keep_indices_map.append(keep_idxs)
            except Exception as e:
                # Skip unreadable/malformed episodes instead of aborting the whole scan.
                print(f"[IndexError] {p}: {e}")
        # ---- Persist the freshly built index (best effort) ----
        try:
            cache = {
                "signature": sig,
                "T": self.T,
                "returns_energy": self.returns_energy,
                "returns_comfort": self.returns_comfort,
                "state_keys": self.state_keys,
                "action_keys": self.action_keys,
                "keep_indices_map": self.keep_indices_map,
                "comfort_idx": self.comfort_idx,  # Added
            }
            with open(INDEX_CACHE_PATH, "w") as f:
                json.dump(cache, f)
            print(f"[DataLoader] Saved index cache: {INDEX_CACHE_PATH}")
        except Exception as e:
            print(f"[DataLoader] Warning: failed to save cache: {e}")

    def __len__(self):
        # Number of successfully indexed episodes.
        return len(self.T)
# ============================================================
# 2) NORMALIZATION
# ============================================================
def compute_and_save_norm_stats(npz_paths: List[str], index: "EpisodeIndex", max_episodes: int = 1000, stride: int = 4):
    """Compute feature-wise normalization stats and RTG scales; save to disk.

    Accumulates running sum / sum-of-squares over up to `max_episodes`
    randomly chosen episodes, time-subsampled by `stride`, to derive
    per-column mean/std for observations and actions. Return scales are the
    95th percentile of |episode return| (floored at 1.0). Results are
    written to NORM_CACHE_PATH as a compressed .npz.

    Raises:
        RuntimeError: if the index contains no episodes.
        ValueError: if no accumulator was ever initialized (no data read).
    """
    rng = np.random.default_rng(SEED)
    n = len(index)
    if n == 0:
        raise RuntimeError("EpisodeIndex is empty (no valid episodes).")
    k = min(max_episodes, n)
    eps_idx = rng.choice(np.arange(n), size=k, replace=False)
    obs_sum, obs_sumsq = None, None
    act_sum, act_sumsq = None, None
    count = 0  # total timesteps accumulated (after striding)
    for ei in tqdm(eps_idx, desc="Computing norm stats"):
        p = index.paths[int(ei)]
        with np.load(p, allow_pickle=True) as d:
            obs = d["observations"].astype(np.float32)
            act = d["actions"].astype(np.float32)
        # Restrict to the kept observation columns, then subsample in time.
        keep_idxs = index.keep_indices_map[int(ei)]
        obs = obs[:, keep_idxs]
        obs = obs[::stride]
        act = act[::stride]
        if obs_sum is None:
            # Lazily size the accumulators from the first episode's dimensions.
            obs_sum = np.zeros(obs.shape[1], dtype=np.float64)
            obs_sumsq = np.zeros(obs.shape[1], dtype=np.float64)
            act_sum = np.zeros(act.shape[1], dtype=np.float64)
            act_sumsq = np.zeros(act.shape[1], dtype=np.float64)
        obs_sum += obs.sum(axis=0)
        obs_sumsq += (obs**2).sum(axis=0)
        act_sum += act.sum(axis=0)
        act_sumsq += (act**2).sum(axis=0)
        count += obs.shape[0]
    if obs_sum is None or obs_sumsq is None or act_sum is None or act_sumsq is None:
        raise ValueError("obs_sum, obs_sumsq, act_sum, or act_sumsq is not initialized properly.")
    # Mean/variance from the moment sums; variance floored at 1e-6 to keep std > 0.
    obs_mean = (obs_sum / max(count, 1)).astype(np.float32)
    obs_std = np.sqrt(np.maximum((obs_sumsq / max(count, 1)) - obs_mean**2, 1e-6)).astype(np.float32)
    act_mean = (act_sum / max(count, 1)).astype(np.float32)
    act_std = np.sqrt(np.maximum((act_sumsq / max(count, 1)) - act_mean**2, 1e-6)).astype(np.float32)
    # Robust return scales: 95th percentile of |return|, at least 1.0.
    all_re = np.abs(np.array(index.returns_energy))
    all_rc = np.abs(np.array(index.returns_comfort))
    scale_energy = float(np.percentile(all_re, 95)) if len(all_re) > 0 else 1.0
    scale_comfort = float(np.percentile(all_rc, 95)) if len(all_rc) > 0 else 1.0
    scale_energy = max(scale_energy, 1.0)
    scale_comfort = max(scale_comfort, 1.0)
    np.savez_compressed(
        NORM_CACHE_PATH,
        obs_mean=obs_mean, obs_std=obs_std,
        act_mean=act_mean, act_std=act_std,
        scale_energy=np.array([scale_energy], dtype=np.float32),
        scale_comfort=np.array([scale_comfort], dtype=np.float32),
    )
class GeneralistDataset(Dataset):
    """Tokenized HVAC-trajectory dataset for decision-transformer training.

    Samples are virtual: each index i is hashed with (seed, epoch) into a
    deterministic stream that selects an episode, a CONTEXT_LEN window and a
    preference lambda. Each sample packs per-step state/action tokens,
    discretized action targets, dual (energy, comfort) return-to-go channels
    and auxiliary physics targets.

    Fixes vs. the previous revision:
    * Window candidate sampling used ``np.random.randint(0, T_total - L)``,
      which raises ValueError when T_total == L; the upper bound is now
      inclusive of T_total - L.
    * Local reward variables no longer shadow the imported ``re`` module.
    * All per-sample randomness is drawn from the per-(seed, epoch, i)
      Generator instead of the global np.random state, so samples are
      reproducible as the seeding machinery intends.
    """

    def __init__(
        self,
        npz_paths: List[str],
        max_tokens: int = MAX_TOKENS_PER_STEP,
        seed: int = SEED,
        virtual_len: int = 60_000,
        gamma_rtg: float = 1.0,
        topk_frac: Optional[float] = None,
        topk_mode: Optional[str] = None,
        topk_on: Optional[str] = None,
    ):
        self.index = EpisodeIndex(npz_paths)
        self.max_tokens = int(max_tokens)
        self.seed = int(seed)
        self.virtual_len = int(virtual_len)
        self.epoch = 0
        self.gamma_rtg = float(gamma_rtg)
        self.is_train = True  # training mode adds small RTG jitter
        self.all_eps = np.arange(len(self.index), dtype=np.int64)
        # ---------------- Top-K episode selection ----------------
        # Passing topk_frac explicitly force-enables Top-K selection.
        self.use_topk = bool(USE_TOPK) if topk_frac is None else True
        self.topk_frac = float(TOPK_FRAC) if topk_frac is None else float(topk_frac)
        self.topk_mode = str(TOPK_MODE) if topk_mode is None else str(topk_mode)
        self.topk_on = str(TOPK_ON) if topk_on is None else str(topk_on)
        rets_e = np.asarray(self.index.returns_energy, dtype=np.float32)
        rets_c = np.asarray(self.index.returns_comfort, dtype=np.float32)
        self.sel_eps = self.all_eps
        self.weights = None
        if self.use_topk and len(self.all_eps) > 0:
            total_k = max(1, int(round(self.topk_frac * len(self.all_eps))))
            if self.topk_on == "pareto":
                # === STRATEGY 1: PARETO UNION (Energy + Comfort + Mixed) ===
                print("[Top-K] Strategy: Energy + Comfort + Mixed")
                k_part = max(1, total_k // 3)
                idx_energy = np.argsort(rets_e)[::-1][:k_part]   # best energy return
                idx_comfort = np.argsort(rets_c)[::-1][:k_part]  # best comfort return
                # Best balanced episodes via z-scored sum of both returns.
                norm_e = (rets_e - rets_e.mean()) / (rets_e.std() + 1e-6)
                norm_c = (rets_c - rets_c.mean()) / (rets_c.std() + 1e-6)
                idx_mixed = np.argsort(norm_e + norm_c)[::-1][:k_part]
                top_eps = np.unique(np.concatenate([idx_energy, idx_comfort, idx_mixed]))
            else:
                # === STRATEGY 2: single ranking signal ===
                if self.topk_on == "energy": rank_signal = rets_e
                elif self.topk_on == "comfort": rank_signal = rets_c
                elif self.topk_on == "mixed": rank_signal = rets_e + rets_c
                else: rank_signal = rets_e  # Fallback
                order = np.argsort(rank_signal)[::-1]
                top_eps = order[:total_k]
            # === APPLY FILTER ===
            # NOTE(review): "weighted" currently behaves exactly like "filter"
            # (weights stay None) and TOPK_BOOST is never applied — confirm
            # intended behavior before relying on weighted mode.
            if self.topk_mode == "filter":
                self.sel_eps = top_eps
                self.weights = None
            elif self.topk_mode == "weighted":
                self.sel_eps = top_eps
                self.weights = None
        # ---------------- Normalization stats ----------------
        if USE_NORMALIZATION:
            if not os.path.exists(NORM_CACHE_PATH):
                print("[DataLoader] Computing Norm Stats...")
                compute_and_save_norm_stats(npz_paths, self.index)
            z = np.load(NORM_CACHE_PATH)
            self.obs_mean = z["obs_mean"].astype(np.float32)
            self.obs_std = z["obs_std"].astype(np.float32)
            self.act_mean = z["act_mean"].astype(np.float32)
            self.act_std = z["act_std"].astype(np.float32)
            self.scale_energy = float(z["scale_energy"][0])
            self.scale_comfort = float(z["scale_comfort"][0])
        else:
            self.obs_mean = None
            self.scale_energy = 1.0
            self.scale_comfort = 1.0

    def set_epoch(self, e: int):
        """Advance the virtual-sampling epoch (changes which windows are drawn)."""
        self.epoch = int(e)

    def __len__(self):
        # Virtual length: windows are sampled, not enumerated.
        return self.virtual_len

    def __getitem__(self, i: int) -> Dict[str, Any]:
        """Build one training sample; see the class docstring for the layout."""
        # Deterministic per-(seed, epoch, i) hash; all randomness below comes
        # from `rng`, seeded from it, so samples are reproducible.
        x = _mix_u64(self.seed ^ (self.epoch * 0x9E3779B97F4A7C15) ^ (int(i) * 0xD1B54A32D192ED03))
        rng = np.random.default_rng(int(x & 0xFFFFFFFF))
        # Preference lambda: weighting between energy and comfort objectives.
        if PREF_MODE == "fixed":
            lam = float(PREF_FIXED_LAMBDA)
        else:
            lam = float(rng.beta(PREF_BETA_A, PREF_BETA_B))
        # Episode selection: uniform over sel_eps, or CDF sampling when weighted.
        if self.weights is None:
            ep_i = int(self.sel_eps[x % len(self.sel_eps)])
        else:
            u = ((x & 0xFFFFFFFF) / 2**32)
            cdf = np.cumsum(self.weights)
            idx = int(np.searchsorted(cdf, u, side="right"))
            idx = min(idx, len(self.weights) - 1)  # clip to avoid out-of-bounds
            ep_i = int(self.sel_eps[idx])
        p = self.index.paths[ep_i]
        T_total = int(self.index.T[ep_i])
        L = CONTEXT_LEN
        # 1) Load episode arrays. Local names avoid shadowing the `re` module.
        with np.load(p, allow_pickle=True) as d:
            raw_obs = d["observations"].astype(np.float32)
            at = d["actions"].astype(np.float32)
            if "rewards_energy" in d:
                rew_e = d["rewards_energy"].astype(np.float32)
                rew_c = d["rewards_comfort"].astype(np.float32)
            else:
                rew_e = d["rewards"].astype(np.float32)
                rew_c = np.zeros_like(rew_e)
        # 2) Pick a window start, biased toward high-return windows via a
        #    softmax over a few uniformly drawn candidate starts.
        if T_total >= L:
            total_r = rew_e + rew_c
            num_candidates = 20
            # FIX: upper bound is inclusive of T_total - L; randint(0, 0)
            # raised ValueError whenever T_total == L.
            candidates = rng.integers(0, T_total - L + 1, size=num_candidates)
            scores = np.array([total_r[c : c + L].sum() for c in candidates])
            scores_stab = (scores - np.max(scores)) / (np.std(scores) + 1e-6)
            probs = np.exp(scores_stab)
            probs /= probs.sum()
            s0 = int(rng.choice(candidates, p=probs))
        else:
            s0 = 0
        # ASH55 comfort columns are indexed against the RAW observation array.
        cidx = self.index.comfort_idx[ep_i]
        if len(cidx) > 0:
            ash55_raw_slice = raw_obs[:, cidx]
        else:
            ash55_raw_slice = np.zeros((T_total, 1), dtype=np.float32)
        keep_idxs = self.index.keep_indices_map[ep_i]
        st = raw_obs[:, keep_idxs]
        s_keys_ep = self.index.state_keys[ep_i]

        def find_idx(substring):
            # First state column whose name contains `substring`, else -1.
            for idx, k in enumerate(s_keys_ep):
                if substring in k.lower():
                    return idx
            return -1

        idx_out = find_idx("outdoor_temp")
        idx_dew = find_idx("dewpoint")
        idx_hr = find_idx("hour")
        idx_mth = find_idx("month")
        idx_occ = compute_occupancy_indices_from_state_keys(s_keys_ep)

        def get_window(arr, pad_val=0.0):
            # Slice [s0, s0+L), or zero-pad short episodes up to L steps.
            if T_total >= L:
                return arr[s0:s0+L]
            out = np.full((L, *arr.shape[1:]), pad_val, dtype=np.float32)
            out[:T_total] = arr
            return out

        st_win = get_window(st)
        at_win = get_window(at)
        at_win_raw = at_win.copy()  # raw actions kept for discretized targets
        re_win = get_window(rew_e)
        rc_win = get_window(rew_c)
        ash55_win = get_window(ash55_raw_slice)
        ash55_any = ash55_win.mean(axis=1).astype(np.float32)
        tm_win = np.zeros((L,), dtype=np.float32)
        valid_len = min(T_total, L)
        tm_win[:valid_len] = 1.0
        valid_mask = (tm_win > 0.5)
        # Outdoor-temperature "forecast": mean over the steps after the window,
        # falling back to the in-window mean when no future data exists.
        FORECAST_STEPS = 48
        future_start = s0 + L
        future_end = min(T_total, future_start + FORECAST_STEPS)
        forecast_temp = 0.0
        if idx_out != -1:
            current_vals = st_win[valid_mask, idx_out]
            if len(current_vals) > 0:
                forecast_temp = current_vals.mean()
            if future_end > future_start:
                future_vals = st[future_start:future_end, idx_out]
                if len(future_vals) > 0:
                    forecast_temp = future_vals.mean()
        # 3) Context vector: weather, occupancy and cyclical time features.
        t_mean, t_std = 0.0, 0.0
        if idx_out != -1 and valid_mask.sum() > 0:
            vals = st_win[valid_mask, idx_out]
            t_mean, t_std = vals.mean(), vals.std()
        d_mean = 0.0
        if idx_dew != -1 and valid_mask.sum() > 0:
            d_mean = st_win[valid_mask, idx_dew].mean()
        occ_frac = 0.0
        if len(idx_occ) > 0 and valid_mask.sum() > 0:
            occ_sum = st_win[valid_mask][:, idx_occ].sum(axis=1)
            occ_frac = (occ_sum > 0.5).mean()  # fraction of steps with anyone present
        # Cyclical encodings of hour-of-day / month at the window start.
        hr_sin, hr_cos = 0.0, 0.0
        if idx_hr != -1 and valid_mask.sum() > 0:
            hr_val = st_win[valid_mask, idx_hr][0]
            hr_sin = np.sin(2 * np.pi * hr_val / 24.0)
            hr_cos = np.cos(2 * np.pi * hr_val / 24.0)
        mth_sin, mth_cos = 0.0, 0.0
        if idx_mth != -1 and valid_mask.sum() > 0:
            mth_val = st_win[valid_mask, idx_mth][0]
            mth_sin = np.sin(2 * np.pi * mth_val / 12.0)
            mth_cos = np.cos(2 * np.pi * mth_val / 12.0)
        ctx_vec = np.array([
            t_mean, t_std, d_mean, occ_frac,
            hr_sin, hr_cos, mth_sin, mth_cos,
            forecast_temp,
            0.0  # reserved slot
        ], dtype=np.float32)
        # 4) Auxiliary targets: next-step obs and PHYSICS_HORIZON-step-ahead obs.
        next_st_win = np.zeros_like(st_win)
        future_4h_st_win = np.zeros_like(st_win)
        if T_total >= L:
            end_idx = min(s0 + L + 1, T_total)
            actual_len = end_idx - (s0 + 1)
            if actual_len > 0:
                next_st_win[:actual_len] = st[s0+1 : end_idx]
            f_end_idx = min(s0 + L + PHYSICS_HORIZON, T_total)
            f_actual_len = f_end_idx - (s0 + PHYSICS_HORIZON)
            if f_actual_len > 0:
                future_4h_st_win[:f_actual_len] = st[s0 + PHYSICS_HORIZON : f_end_idx]
        else:
            if T_total > 1:
                next_st_win[:T_total-1] = st[1:T_total]
        # 5) Normalize; targets use the same stats as the inputs.
        if USE_NORMALIZATION and (self.obs_mean is not None):
            st_win = (st_win - self.obs_mean) / self.obs_std
            next_st_win = (next_st_win - self.obs_mean) / self.obs_std
            future_4h_st_win = (future_4h_st_win - self.obs_mean) / self.obs_std
            at_win = (at_win - self.act_mean) / self.act_std
        delta_4h_win = future_4h_st_win - st_win
        # 6) Dual return-to-go channels, scaled by dataset-level 95th percentiles.
        full_rtg_e = discounted_cumsum(rew_e, gamma=self.gamma_rtg)
        full_rtg_c = discounted_cumsum(rew_c, gamma=self.gamma_rtg)
        rtg_e_win = get_window(full_rtg_e)
        rtg_c_win = get_window(full_rtg_c)
        rtg_e_norm = rtg_e_win / self.scale_energy
        rtg_c_norm = rtg_c_win / self.scale_comfort
        rtg_combined = np.stack([rtg_e_norm, rtg_c_norm], axis=-1)
        if getattr(self, "is_train", True):
            # Small jitter regularizes RTG conditioning during training.
            rtg_combined += rng.normal(0, 0.005, rtg_combined.shape).astype(np.float32)
        # 7) Pack per-step tokens: [state tokens | action tokens | PAD...].
        feat_ids = np.full((L, self.max_tokens), PAD_ID, dtype=np.int64)
        feat_vals = np.zeros((L, self.max_tokens), dtype=np.float32)
        zone_ids = np.zeros((L, self.max_tokens), dtype=np.int64)
        attn_mask = np.zeros((L, self.max_tokens), dtype=np.int64)
        target_toks = np.full((L, self.max_tokens), -100, dtype=np.int64)  # -100 = ignore index
        target_mask = np.zeros((L, self.max_tokens), dtype=np.float32)
        s_meta = self.index.s_meta[ep_i]
        a_meta = self.index.a_meta[ep_i]
        S_dim = min(len(s_meta), st_win.shape[1])
        A_dim = min(len(a_meta), at_win.shape[1])
        num_act_toks = min(A_dim, self.max_tokens)  # actions are always given room first
        num_state_toks = min(S_dim, self.max_tokens - num_act_toks)
        if num_state_toks > 0:
            feat_ids[:, :num_state_toks] = [m[0] for m in s_meta[:num_state_toks]]
            zone_ids[:, :num_state_toks] = [m[1] for m in s_meta[:num_state_toks]]
            feat_vals[:, :num_state_toks] = st_win[:, :num_state_toks]
            attn_mask[:, :num_state_toks] = 1
        if num_act_toks > 0:
            start = num_state_toks
            end = start + num_act_toks
            feat_ids[:, start:end] = [m[0] for m in a_meta[:num_act_toks]]
            zone_ids[:, start:end] = [m[1] for m in a_meta[:num_act_toks]]
            attn_mask[:, start:end] = 1
            # Action-token input values carry the PREVIOUS step's (normalized)
            # actions (ACTION_VALUE_INPUT_MODE == "prev"); t=0 sees zeros.
            a_in = np.zeros((L, num_act_toks), dtype=np.float32)
            if L > 1:
                a_in[1:] = at_win[:-1, :num_act_toks]
            feat_vals[:, start:end] = a_in
            # Targets are the CURRENT step's raw actions, discretized to bins.
            a_keys = self.index.action_keys[ep_i]
            at_discrete = discretize_actions_to_bins(at_win_raw, a_keys)
            target_toks[:, start:end] = at_discrete[:, :num_act_toks]
            target_mask[:, start:end] = 1.0
        # Mask out padded (beyond-episode) timesteps.
        valid_t = (tm_win > 0.5)[:, None]
        attn_mask *= valid_t.astype(np.int64)
        target_mask *= valid_t
        return {
            "feature_ids": feat_ids,
            "feature_values": feat_vals,
            "zone_ids": zone_ids,
            "attention_mask": attn_mask,
            "target_action_tokens": target_toks,
            "target_mask": target_mask,
            "rtg": rtg_combined,
            "rtg_energy": rtg_e_norm,
            "rtg_comfort": rtg_c_norm,
            "rewards_energy": re_win,
            "rewards_comfort": rc_win,
            "pref_lambda": np.float32(lam),
            "ash55_any": ash55_any,
            "next_obs": next_st_win,
            "target_4h_delta": delta_4h_win,
            "time_mask": tm_win,
            "context": ctx_vec,
        }
def generalist_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Stack per-sample numpy dicts into a batched dict of torch tensors.

    Integer id/mask fields become int64 tensors; everything else float32.
    """
    long_keys = ("feature_ids", "zone_ids", "attention_mask", "target_action_tokens")
    all_keys = (
        "feature_ids", "feature_values", "zone_ids", "attention_mask",
        "target_action_tokens", "target_mask", "rtg", "rtg_energy",
        "rtg_comfort", "rewards_energy", "rewards_comfort", "pref_lambda",
        "ash55_any", "next_obs", "target_4h_delta", "time_mask", "context",
    )

    def batched(key: str) -> torch.Tensor:
        tensor = torch.from_numpy(np.stack([sample[key] for sample in batch]))
        return tensor.long() if key in long_keys else tensor.float()

    return {key: batched(key) for key in all_keys}
# ============================================================
# 4) DEBUG MAIN
# ============================================================
def main():
    """Smoke test: glob trajectory files, build the dataset, fetch one batch."""
    pattern = os.path.join(DATA_DIR, "TrajectoryData_officesmall", "**", "traj_ep*_seed*.npz")
    candidates = sorted(glob.glob(pattern, recursive=True))
    npz_paths = [p for p in candidates if os.path.basename(p) not in ("norm_stats.npz",)]
    if not npz_paths:
        print(f"No data found in {DATA_DIR}")
        return
    dataset = GeneralistDataset(npz_paths, max_tokens=64)
    loader = DataLoader(dataset, batch_size=4, collate_fn=generalist_collate_fn, num_workers=0)
    batch = next(iter(loader))

if __name__ == "__main__":
    main()