# Source: Hugging Face Hub file view — uploaded by multimodalart (HF Staff)
# via "Upload folder using huggingface_hub", commit 6e4b62e (verified), 19 kB.
import logging
import os
from typing import Optional
import torch
import torch.distributed as dist
from accelerate import Accelerator
from tqdm import tqdm
from src.model_training.transformers_compat import patch_transformers_hybrid_cache
patch_transformers_hybrid_cache()
from diffsynth.trainers.utils import DiffusionTrainingModule
from src.model_training.fov_retrieval import FOVMemoryRetriever
from src.model_training.fov_retrieval import retrieve_context_frames_advanced, retrieve_fov_context_frames
from src.model_training.training_modules.model_logger import ModelLogger
logger = logging.getLogger(__name__)
def launch_training_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
    per_device_train_batch_size: int = 1,
    seed: int = 42,
    spike_threshold: float = 5.0,
    resume_step_count: int = 0,
    enable_fov_retrieval: bool = False,
    retrieval_method: str = "fov",  # fov | latent_sim
    latent_retrieval_dir: Optional[str] = None,
    dataset_base_path: Optional[str] = None,
    fov_retriever: Optional[FOVMemoryRetriever] = None,
    context_memory_frames: int = 5,
    prev_chunk_frames: int = 81,
    fov_top_k: int = 4,  # Number of overlap frames to retrieve. GT frame 0 will be added automatically.
    use_rt_relative: bool = False,  # Experiment 1_4_2: Use RT relative conversion (aligned with Context-as-Memory)
    strict_overlap_context: bool = False,
    dataset_repeat: int = 1,  # Only used for the total-step calculation
    use_camera_encoder: bool = False,  # exp1_4_3: use CameraEncoder (action_mlp unused -> need find_unused_parameters)
    num_workers: int = 0,  # DataLoader workers: 0=main process, >0=parallel preload (recommend 4 for video)
    context_source: str = "fov",
    max_train_steps: int = 0,
    progress_total_steps: int = 0,
):
    """Run the accelerate-driven training loop, with optional context-frame injection.

    Each batch is a list of sample dicts. Before the forward pass, context frames
    are attached to every sample according to ``context_source`` /
    ``enable_fov_retrieval``. ``model(batch)`` must return a scalar loss tensor.
    Loss spikes relative to an EMA of the loss are skipped.

    Skip decisions are agreed on by every rank so that DDP collectives stay
    aligned. This covers batches where every item failed to load, strict-mode
    context failures and loss spikes.

    Args:
        dataset: Map-style dataset. Items may be None when loading fails; such
            items are dropped by the collate function.
        model: Training module; called as ``model(list_of_sample_dicts)``.
        model_logger: Receives ``on_step_end`` / ``on_epoch_end`` callbacks. Its
            ``enable_video_sampling``, ``total_steps`` and ``context_drop_prob``
            attributes are read; ``total_steps`` may be overwritten.
        optimizer: Optimizer; prepared by accelerate inside this function.
        scheduler: LR scheduler; prepared by accelerate inside this function.
        num_epochs: Number of passes over the dataloader.
        gradient_accumulation_steps: Forwarded to ``Accelerator``.
        per_device_train_batch_size: DataLoader batch size per process.
        seed: Each epoch seeds torch (and CUDA) with ``seed + epoch_id``.
        spike_threshold: Skip the update when ``loss / ema_loss`` exceeds this
            value. It is 1.5x larger for the first 500 steps after a resume.
        resume_step_count: Starting value of the step counter. When > 0, the
            first epoch discards ``resume_step_count // gradient_accumulation_steps``
            batches first, and spike detection is delayed by an adaptation window.
        enable_fov_retrieval: Use FOV / latent-similarity retrieval. Applies when
            ``context_source`` is "fov" (or any unrecognised value).
        retrieval_method: "latent_sim" uses ``retrieve_context_frames_advanced``;
            anything else uses ``retrieve_fov_context_frames``.
        latent_retrieval_dir: Passed to latent-similarity retrieval.
        dataset_base_path: Root used by every context source. No context is
            attached when it is falsy.
        fov_retriever: Passed through to ``retrieve_fov_context_frames``.
        context_memory_frames: Not read here. The frame count comes from the
            unwrapped model's ``context_memory_frames`` attribute.
        prev_chunk_frames: Upper bound on segment frames used for synthetic
            replay context.
        fov_top_k: Number of overlap frames to retrieve.
        use_rt_relative: Passed to retrieval / prev-chunk action loading.
        strict_overlap_context: If True, a step whose context could not be
            attached is skipped. It is also forwarded to retrieval as
            ``strict_overlap_labels``.
        dataset_repeat: Multiplier for the dataset size in the total-step
            calculation.
        use_camera_encoder: Enables DDP ``find_unused_parameters``.
        num_workers: DataLoader worker count (negative values are clamped to 0).
        context_source: "fov", "replay" or "prev_chunk_tail" (case- and
            whitespace-insensitive). Unrecognised values fall back to "fov".
        max_train_steps: When non-zero, return (without ``on_epoch_end``) once
            the step counter reaches this value after an optimizer step.
        progress_total_steps: Progress-bar total. When 0, ``len(dataloader)`` is
            used and the bar starts at 0. Otherwise the bar starts at
            ``resume_step_count``.

    Raises:
        ValueError: After 100 consecutive batches in which every item failed to load.

    Side effects:
        Sets the ``TORCH_DISTRIBUTED_DEFAULT_TIMEOUT`` environment variable and
        adds context keys to the sample dicts in place.
    """
    prev_chunk_frames = int(prev_chunk_frames)
    # Loop-invariant: normalise the context source once; unknown values fall back to "fov".
    cs = context_source.strip().lower()
    if cs not in ("fov", "replay", "prev_chunk_tail"):
        cs = "fov"
    # Strict mode only applies when some context source is active.
    context_required = enable_fov_retrieval or cs in ("replay", "prev_chunk_tail")

    # VideoDataset can return None when file loading fails. Failed items are
    # dropped; a batch where every item failed becomes None and is skipped in
    # lockstep across ranks (see the training loop).
    def collate_fn(batch):
        valid_batch = [item for item in batch if item is not None]
        return valid_batch or None

    num_workers = max(0, int(num_workers))
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=per_device_train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
        drop_last=True,
        persistent_workers=(num_workers > 0),
        pin_memory=(num_workers > 0 and torch.cuda.is_available()),
    )
    if num_workers > 0:
        logger.info(f"[DataLoader] num_workers={num_workers}, persistent_workers=True, pin_memory={torch.cuda.is_available()} (data preload parallel to GPU)")

    # NOTE(review): nothing in this function reads this variable back. Confirm
    # that the launcher/torch consumes it; if not, pass
    # accelerate.InitProcessGroupKwargs(timeout=...) to the Accelerator instead.
    timeout_seconds = int(os.environ.get('TORCH_DISTRIBUTED_DEFAULT_TIMEOUT', 2400))
    os.environ['TORCH_DISTRIBUTED_DEFAULT_TIMEOUT'] = str(timeout_seconds)
    logger.info(f"[Timeout Config] Setting TORCH_DISTRIBUTED_DEFAULT_TIMEOUT={timeout_seconds} seconds ({timeout_seconds/60:.1f} minutes)")

    # Conditional context paths can leave parameters unused on some iterations.
    need_find_unused = bool(use_camera_encoder) or model_logger.context_drop_prob > 0.0
    if need_find_unused:
        from accelerate import DistributedDataParallelKwargs
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, kwargs_handlers=[ddp_kwargs])
        logger.info("[DDP] find_unused_parameters=True (conditional modules / context_drop_prob enabled)")
    else:
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    if model_logger.enable_video_sampling and model_logger.total_steps is not None:
        model_logger.total_steps = _compute_total_steps(
            accelerator,
            dataset_size=len(dataset),
            dataset_repeat=dataset_repeat,
            num_epochs=num_epochs,
            gradient_accumulation_steps=gradient_accumulation_steps,
            per_device_train_batch_size=per_device_train_batch_size,
        )

    step = resume_step_count
    traj_loss = 0.0  # EMA of the (per-rank) loss; 0.0 means "not initialised yet"
    if resume_step_count > 0:
        adaptation_steps = max(200, resume_step_count // 100)
        spike_detection_start_step = resume_step_count + adaptation_steps
        logger.info(f"Resuming from step {resume_step_count}, spike detection will start at step {spike_detection_start_step} (after {adaptation_steps} adaptation steps)")
    else:
        spike_detection_start_step = 100

    # Most recent valid batch. It is reported to the logger on skipped steps,
    # and stays None until the first valid batch arrives.
    samples = None

    for epoch_id in range(num_epochs):
        epoch_seed = seed + epoch_id
        torch.manual_seed(epoch_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(epoch_seed)
            torch.cuda.manual_seed_all(epoch_seed)

        # Iterate a per-epoch source so that `dataloader` itself stays usable
        # for len() and for later epochs (rebinding it to a consumed iterator
        # would leave every later epoch empty).
        epoch_batches = dataloader
        if resume_step_count > 0 and epoch_id == 0:
            estimated_skip = resume_step_count // gradient_accumulation_steps
            if estimated_skip > 0:
                logger.info(f"Skipping {estimated_skip} data samples to resume from step {resume_step_count}...")
                epoch_batches = iter(dataloader)
                for _ in tqdm(range(estimated_skip), desc="Skipping data", unit="samples", leave=False):
                    try:
                        next(epoch_batches)
                    except StopIteration:
                        break
                logger.info(f"Successfully skipped {estimated_skip} data samples, resuming training...")

        # Track consecutive None data to detect if we're stuck in a loop
        consecutive_none_count = 0
        max_consecutive_none = 100  # If we get 100 consecutive None values, something is wrong
        progress_total = int(progress_total_steps)
        if progress_total <= 0:
            progress_total = len(dataloader)
        progress_bar = tqdm(
            epoch_batches,
            total=progress_total,
            initial=resume_step_count if progress_total_steps else 0,
            desc="Training steps",
            unit="step",
        )
        for data_idx, data in enumerate(progress_bar):
            local_skip = False
            if data is None:
                # Every item in this batch failed to load.
                consecutive_none_count += 1
                if consecutive_none_count >= max_consecutive_none:
                    logger.error(f"Received {max_consecutive_none} consecutive None data samples. This suggests a serious dataset issue. Stopping training.")
                    raise ValueError(f"Too many consecutive None data samples ({max_consecutive_none}). Check dataset files.")
                if consecutive_none_count <= 10 or consecutive_none_count % 10 == 0:
                    logger.warning(f"Received None data at index {data_idx} (consecutive: {consecutive_none_count}). This may indicate missing or corrupted files. Skipping...")
                local_skip = True
            else:
                consecutive_none_count = 0
                # Normalize to list of samples for batch processing (per_device_train_batch_size > 1)
                samples = data if isinstance(data, list) else [data]
                cm_frames = int(accelerator.unwrap_model(model).context_memory_frames)
                context_ok = _attach_context(
                    samples,
                    context_source=cs,
                    context_memory_frames=cm_frames,
                    dataset_base_path=dataset_base_path,
                    prev_chunk_frames=prev_chunk_frames,
                    use_rt_relative=use_rt_relative,
                    enable_fov_retrieval=enable_fov_retrieval,
                    retrieval_method=retrieval_method,
                    latent_retrieval_dir=latent_retrieval_dir,
                    fov_retriever=fov_retriever,
                    fov_top_k=fov_top_k,
                    strict_overlap_context=strict_overlap_context,
                )
                # Strict mode: if we require context but retrieval failed, skip this step
                if strict_overlap_context and not context_ok and context_required:
                    if step % 50 == 0 and accelerator.is_main_process:
                        logger.warning(f"[CONTEXT][STRICT] No context at step={step}, skipping this training sample.")
                    local_skip = True

            # If any rank skips, all ranks skip. Otherwise the ranks would run
            # different numbers of forward/backward passes and the DDP
            # collectives would misalign or hang.
            if _any_rank_true(accelerator, local_skip):
                # Still increment step to keep step_count synchronized
                step += 1
                dummy_loss = torch.tensor(0.0, device=accelerator.device, requires_grad=False)
                model_logger.on_step_end(dummy_loss, accelerator, model, current_batch=samples)
                continue

            with accelerator.accumulate(model):
                # One forward over the full batch (data is a list of B sample dicts)
                loss = model(data)
                step += 1
                loss_value = loss.item()  # single host sync per step
                if traj_loss == 0.0:
                    traj_loss = loss_value
                else:
                    alpha = 0.01
                    traj_loss = (1 - alpha) * traj_loss + alpha * loss_value

                if step >= spike_detection_start_step:
                    relative_loss = loss_value / traj_loss if traj_loss > 0 else 0.0
                    if resume_step_count > 0 and step < resume_step_count + 500:
                        effective_threshold = spike_threshold * 1.5
                    else:
                        effective_threshold = spike_threshold
                    should_skip = traj_loss > 0 and relative_loss > effective_threshold
                    # Every rank takes part in this collective (step is aligned
                    # across ranks), so the skip decision is identical everywhere.
                    if _any_rank_true(accelerator, should_skip):
                        if accelerator.is_main_process:
                            logger.warning(f"Spike detected at step {step} (loss={loss_value:.4f}, traj_loss={traj_loss:.4f}, ratio={relative_loss:.2f}), sync skip across all ranks")
                        dummy_loss = torch.tensor(0.0, device=accelerator.device, requires_grad=False)
                        model_logger.on_step_end(dummy_loss, accelerator, model, current_batch=samples)
                        # The prepared optimizer only zeroes on sync steps. On a
                        # sync step this discards the skipped cycle's gradients;
                        # otherwise the gradients accumulated so far are kept.
                        optimizer.zero_grad()
                        del loss
                        torch.cuda.empty_cache()
                        continue

                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(loss, accelerator, model, current_batch=samples)
                scheduler.step()
                # Zero after stepping (accelerate's documented order). Zeroing at
                # the top of the block would wipe accumulated micro-batch
                # gradients on the sync step when gradient_accumulation_steps > 1.
                optimizer.zero_grad()

            if max_train_steps and step >= max_train_steps:
                if progress_total_steps:
                    progress_bar.n = min(step, progress_bar.total) if progress_bar.total is not None else step
                    progress_bar.refresh()
                progress_bar.close()
                if accelerator.is_main_process:
                    logger.info(f"[TRAIN] Reached max_train_steps={max_train_steps}; stopping without epoch checkpoint.")
                accelerator.wait_for_everyone()
                return
        progress_bar.close()
        model_logger.on_epoch_end(accelerator, model, epoch_id)


def _any_rank_true(accelerator, flag: bool) -> bool:
    """Return True if ``flag`` is True on any process (MAX all-reduce); the local value when single-process."""
    if accelerator.num_processes <= 1:
        return bool(flag)
    flag_t = torch.tensor(1.0 if flag else 0.0, device=accelerator.device, dtype=torch.float32)
    dist.all_reduce(flag_t, op=dist.ReduceOp.MAX)
    return flag_t.item() > 0.5


def _compute_total_steps(
    accelerator,
    *,
    dataset_size: int,
    dataset_repeat: int,
    num_epochs: int,
    gradient_accumulation_steps: int,
    per_device_train_batch_size: int,
) -> int:
    """Compute (and log on the main process) the global step total given to the model logger.

    NOTE(review): the training loop increments its per-rank ``step`` once per
    micro-batch. This total divides by ``gradient_accumulation_steps`` and
    multiplies by the process count. Confirm this matches how ModelLogger
    interprets ``total_steps``.
    """
    num_processes = accelerator.num_processes
    effective_dataset_size = dataset_size * dataset_repeat
    total_steps_per_gpu = (effective_dataset_size * num_epochs) // (gradient_accumulation_steps * num_processes * per_device_train_batch_size)
    total_steps_global = total_steps_per_gpu * num_processes
    if accelerator.is_main_process:
        logger.info("="*80)
        logger.info("[Step Calculation] Corrected total_steps after accelerator.init")
        logger.info("="*80)
        logger.info(f" Dataset size (unique samples): {dataset_size}")
        logger.info(f" Dataset repeat: {dataset_repeat}")
        logger.info(f" Effective dataset size: {effective_dataset_size} (unique * repeat)")
        logger.info(f" Number of epochs: {num_epochs}")
        logger.info(f" Number of GPUs: {num_processes}")
        logger.info(f" Gradient accumulation steps: {gradient_accumulation_steps}")
        logger.info(f" Per-device batch size: {per_device_train_batch_size}")
        logger.info(f" Total samples to process: {effective_dataset_size * num_epochs}")
        logger.info(f" Steps per GPU: ~{total_steps_per_gpu}")
        logger.info(f" Total steps (global): {total_steps_global}")
        logger.info("")
        logger.info(f" ✓ Each GPU will process ~{total_steps_per_gpu} steps")
        logger.info(f" ✓ This ensures traversal of all {effective_dataset_size} samples")
        logger.info(f" ({dataset_size} unique samples × {dataset_repeat} repeats)")
        logger.info(f" ✓ Over {num_epochs} epoch(s)")
        logger.info("="*80)
    return total_steps_global


def _attach_context(
    samples,
    *,
    context_source: str,
    context_memory_frames: int,
    dataset_base_path,
    prev_chunk_frames: int,
    use_rt_relative: bool,
    enable_fov_retrieval: bool,
    retrieval_method: str,
    latent_retrieval_dir,
    fov_retriever,
    fov_top_k: int,
    strict_overlap_context: bool,
) -> bool:
    """Attach context frames/actions/metadata to each sample dict in place.

    ``context_source`` must already be normalised to "fov", "replay" or
    "prev_chunk_tail". Nothing is attached when ``dataset_base_path`` is falsy,
    or when the source is "fov" and ``enable_fov_retrieval`` is False.

    Returns:
        False as soon as one sample gets no context frames (later samples in the
        batch are left untouched); True otherwise, including when no context
        source is active.
    """
    if context_source == "replay" and dataset_base_path:
        from src.model_training.multichunk_sample_utils import (
            replay_context_actions_from_segment_actions,
            replay_context_global_indices,
            synthetic_replay_context_from_segment,
        )
        for d in samples:
            vf = d.get("video") or []
            n_seg = min(int(prev_chunk_frames), len(vf)) if vf else 0
            ctx_pil = synthetic_replay_context_from_segment(vf, n_seg, context_memory_frames) if n_seg > 0 else None
            if not ctx_pil:
                return False
            d["context_frames"] = ctx_pil
            d["context_source"] = "replay_synthetic"
            acts = d.get("actions")
            if isinstance(acts, list) and len(acts) >= n_seg:
                ra = replay_context_actions_from_segment_actions(acts[:n_seg], n_seg, context_memory_frames)
                if ra is not None:
                    d["context_actions"] = ra
            sf = int(d.get("start_frame", 0) or 0)
            idxs = replay_context_global_indices(n_seg, context_memory_frames)
            d["context_frame_indices"] = [sf + int(i) for i in idxs]
        return True

    if context_source == "prev_chunk_tail" and dataset_base_path:
        from src.model_training.multichunk_sample_utils import load_prev_chunk_tail_from_disk, load_prev_chunk_tail_rt_actions
        # CONTEXT_POSITION=suffix (the default) requests nearest-first ordering.
        _ctx_pos = os.environ.get("CONTEXT_POSITION", "suffix").strip().lower()
        nearest_first = (_ctx_pos == "suffix")
        for d in samples:
            sf = int(d.get("start_frame", 0) or 0)
            vn = d.get("video_name", "")
            pil_list, idxs = load_prev_chunk_tail_from_disk(
                dataset_base_path, str(vn), sf, context_memory_frames, nearest_first=nearest_first
            )
            if not pil_list:
                return False
            d["context_frames"] = pil_list
            d["context_frame_indices"] = list(idxs) if idxs else []
            d["context_source"] = "prev_chunk_tail"
            ra, _ = load_prev_chunk_tail_rt_actions(
                dataset_base_path,
                str(vn),
                sf,
                context_memory_frames,
                use_rt_relative=use_rt_relative,
                nearest_first=nearest_first,
            )
            if ra:
                d["context_actions"] = ra
        return True

    if enable_fov_retrieval and dataset_base_path:
        for d in samples:
            if retrieval_method == "latent_sim":
                (
                    context_frames,
                    context_actions,
                    context_indices,
                    ref_frame_idx,
                    video_name,
                    source,
                ) = retrieve_context_frames_advanced(
                    data=d,
                    dataset_base_path=dataset_base_path,
                    top_k=fov_top_k,
                    drop_overlap_probability=0.1,
                    use_rt_relative=use_rt_relative,
                    retrieval_method="latent_sim",
                    latent_retrieval_dir=latent_retrieval_dir,
                    strict_overlap_labels=strict_overlap_context,
                )
            else:
                (
                    context_frames,
                    context_actions,
                    context_indices,
                    ref_frame_idx,
                    video_name,
                    source,
                ) = retrieve_fov_context_frames(
                    data=d,
                    dataset_base_path=dataset_base_path,
                    fov_retriever=fov_retriever,  # unused in simplified retrieval, kept for compat
                    top_k=fov_top_k,  # number of overlap frames; GT frame 0 is added automatically
                    use_precomputed_overlaps=True,
                    strict_overlap_labels=strict_overlap_context,
                    allow_realtime_fallback=(not strict_overlap_context),
                    allow_segment_fallback=(not strict_overlap_context),
                )
            if not (context_frames and len(context_frames) > 0):
                return False
            # Use retrieved frames as context
            d["context_frames"] = context_frames
            if context_actions:
                d["context_actions"] = context_actions
            # Store retrieval metadata for visualization/debugging
            d["context_frame_indices"] = context_indices
            d["context_ref_frame_idx"] = ref_frame_idx
            d["context_video_name"] = video_name
            d["context_source"] = source
        return True

    return True