| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Main entry point to run the experiments. Contains general setup and the proper training code. |
| | """ |
| |
|
| | import argparse |
| | import datetime as dt |
| | import json |
| | import os |
| | import random |
| | import sys |
| | import textwrap |
| | import time |
| | from collections.abc import Callable |
| | from contextlib import AbstractContextManager, nullcontext |
| | from functools import partial |
| | from typing import Any, Literal, Optional |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.amp import GradScaler, autocast |
| | from tqdm import tqdm |
| | from transformers import GenerationConfig, set_seed |
| | from utils import ( |
| | FILE_NAME_TRAIN_PARAMS, |
| | BucketIterator, |
| | TrainResult, |
| | TrainStatus, |
| | get_accuracy, |
| | get_base_model_info, |
| | get_dataset_info, |
| | get_file_size, |
| | get_model, |
| | get_optimizer_and_scheduler, |
| | get_peft_branch, |
| | get_tokenizer, |
| | get_train_config, |
| | init_accelerator, |
| | log_results, |
| | validate_experiment_path, |
| | ) |
| |
|
| | from data import get_train_valid_test_datasets, get_wiki_small |
| | from peft import AdaLoraConfig, PeftConfig |
| | from peft.utils import CONFIG_NAME, infer_device |
| |
|
| |
|
| | |
# Passed as `bucket_factor` to the BucketIterator that feeds the training loop in train().
BUCKET_FACTOR = 20
# train() calls empty_cache() on the accelerator module whenever the step number is a multiple of this value.
ACCELERATOR_EMPTY_CACHE_SCHEDULE = 10

# Set at import time, before any model is built.
# NOTE(review): judging by the variable name this turns off TorchInductor's caches; the motivation is not
# visible in this file -- confirm.
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
| |
|
| |
|
def get_generation_config(*, seq_len, generate_kwargs) -> GenerationConfig:
    """Build a ``GenerationConfig`` from ``generate_kwargs``.

    Entries whose value is ``None`` are dropped before the config is created. If both ``max_length`` and
    ``max_new_tokens`` remain afterwards, they are collapsed into a single ``max_length`` equal to the smaller of
    ``max_new_tokens + seq_len`` and the given ``max_length``, and ``max_new_tokens`` is removed.

    The dict passed by the caller is not modified.
    """
    options = {}
    for key, value in generate_kwargs.items():
        if value is not None:
            options[key] = value

    if "max_length" in options and "max_new_tokens" in options:
        # Only one of the two limits is passed on: keep whichever is tighter, expressed as a total length.
        length_from_new_tokens = options.pop("max_new_tokens") + seq_len
        options["max_length"] = min(length_from_new_tokens, options["max_length"])

    return GenerationConfig(**options)
| |
|
| |
|
def evaluate(model, tokenizer, ds, batch_size, generate_kwargs, use_tqdm: bool = False) -> tuple[list[str], list[str]]:
    """Generate with ``model`` for every batch of ``ds`` and collect the decoded outputs.

    ``ds`` is sliced in steps of ``batch_size``; each slice must support ``pop("response")``, and what is left is
    left-padded by the tokenizer and fed to ``model.generate``. ``pad_token_id`` is always set to the tokenizer's
    EOS token id on a copy of ``generate_kwargs``, so the caller's dict is left untouched.

    Returns:
        A tuple ``(predictions, responses)``: the decoded generations (special tokens skipped) and the values
        popped from the ``"response"`` column, both in dataset order.
    """
    gen_options = dict(generate_kwargs)
    gen_options["pad_token_id"] = tokenizer.eos_token_id

    batch_starts = range(0, len(ds), batch_size)
    if use_tqdm:
        batch_starts = tqdm(batch_starts)

    predictions: list[str] = []
    responses: list[str] = []
    with torch.inference_mode():
        for start in batch_starts:
            chunk = ds[start : start + batch_size]
            responses.extend(chunk.pop("response"))
            padded = tokenizer.pad(chunk, return_tensors="pt", padding_side="left").to(model.device)
            # The padded prompt length is needed to translate max_new_tokens into a total max_length.
            prompt_len = padded["input_ids"].shape[1]
            generation_config = get_generation_config(seq_len=prompt_len, generate_kwargs=gen_options)
            generated = model.generate(**padded, generation_config=generation_config)
            predictions.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))
    return predictions, responses
| |
|
| |
|
| | @torch.inference_mode |
| | def calculate_mean_per_token_loss(model, tokenizer, rows: list[str], batch_size: int, max_length: int) -> float: |
| | """Calculate the mean loss per token on the given dataset. |
| | |
| | Useful to determine general model performance before and after training to get an estimate of the magnitude of |
| | 'forgetting'. Note that for Wikipedia data, since the information density is quite high, the loss can be |
| | surprisingly large. |
| | |
| | """ |
| | losses: list[float] = [] |
| | for j in range(0, len(rows), batch_size): |
| | sliced = rows[j : j + batch_size] |
| | batch = tokenizer(sliced, truncation=True, max_length=max_length) |
| | batch = tokenizer.pad(batch, return_tensors="pt", padding_side="left").to(model.device) |
| | outputs = model(**batch, pad_token_id=tokenizer.eos_token_id) |
| | logits = outputs.logits |
| | for logit, target, mask in zip(logits, batch["input_ids"], batch["attention_mask"]): |
| | |
| | num_tokens = mask.sum() |
| | token_losses = torch.nn.functional.cross_entropy(logit[-num_tokens:], target[-num_tokens:], reduction="none") |
| | losses.extend(loss.item() for loss in token_losses) |
| | return torch.tensor(losses).mean().item() |
| |
|
| |
|
class DummyGradScaler:
    """No-op stand-in offering the grad-scaler methods the training loop calls, for runs without AMP."""

    def scale(self, loss):
        """Return ``loss`` unchanged."""
        return loss

    def unscale_(self, optimizer):
        """Do nothing; the gradients were never scaled."""
        return None

    def step(self, optimizer):
        """Call ``optimizer.step()`` directly; its return value is discarded."""
        optimizer.step()

    def update(self):
        """Do nothing; there is no scale factor to adjust."""
        return None
| |
|
| |
|
def train(
    *,
    model: nn.Module,
    max_steps: int,
    batch_size: int,
    batch_size_eval: int,
    tokenizer: Any,
    accelerator_memory_init: int,
    eval_steps: int,
    generation_kwargs: dict[str, Any],
    grad_norm_clip: float,
    optimizer_type: str,
    optimizer_kwargs: dict[str, Any],
    query_template: str,
    lr_scheduler_arg: Optional[Literal["cosine"]],
    use_amp: bool,
    is_adalora: bool,
) -> TrainResult:
    """Train ``model``, validate periodically, and finish with a test-set evaluation.

    Before and after training, the mean per-token loss on a small Wikipedia sample is computed; the difference is
    reported as ``"forgetting"`` in the final metrics entry.

    Args:
        model: Model to train. It is switched to eval mode for every evaluation and back to train mode afterwards.
        max_steps: Upper bound on the number of training steps; the loop also ends when the training iterator is
            exhausted.
        batch_size: Batch size for training and for the Wikipedia loss computation.
        batch_size_eval: Batch size used for generation on the validation and test sets.
        tokenizer: Used to pad batches and decode generations; its EOS token id serves as the pad token id.
        accelerator_memory_init: Baseline subtracted from every accelerator memory reading.
        eval_steps: Validation runs whenever the step number is a multiple of this value; it is also the window
            size for the averaged statistics.
        generation_kwargs: Forwarded to ``evaluate``.
        grad_norm_clip: Maximum gradient norm; a falsy value disables clipping.
        optimizer_type: Forwarded to ``get_optimizer_and_scheduler``.
        optimizer_kwargs: Extra keyword arguments forwarded to ``get_optimizer_and_scheduler``.
        query_template: Forwarded to ``get_train_valid_test_datasets``.
        lr_scheduler_arg: Forwarded to ``get_optimizer_and_scheduler``.
        use_amp: If true, use ``GradScaler`` and ``autocast``; otherwise a no-op scaler and ``nullcontext``.
        is_adalora: If true, ``model.base_model.update_and_allocate(step)`` is called after every step.

    Returns:
        A ``TrainResult`` with the logs collected so far. Its status is ``CANCELED`` when the loop or the final
        evaluation was interrupted by ``KeyboardInterrupt``, an out-of-memory error, or any other ``Exception``,
        and ``SUCCESS`` otherwise. Errors raised during setup (before the loop) are not caught here.
    """
    # Per-step logs; the last `eval_steps` entries of each feed the periodic report.
    accelerator_memory_allocated_log = []
    accelerator_memory_reserved_log = []
    losses = []
    durations = []
    metrics = []
    total_samples = 0
    total_tokens = []

    device_type = infer_device()
    # Fall back to torch.cuda when torch has no attribute named after the inferred device type.
    torch_accelerator_module = getattr(torch, device_type, torch.cuda)
    if use_amp:
        grad_scaler: GradScaler | DummyGradScaler = GradScaler(device=device_type)
        autocast_ctx: Callable[[], AbstractContextManager[Any]] = partial(autocast, device_type=device_type)
    else:
        grad_scaler = DummyGradScaler()
        autocast_ctx = nullcontext

    optimizer, lr_scheduler = get_optimizer_and_scheduler(
        model,
        optimizer_type=optimizer_type,
        max_steps=max_steps,
        lr_scheduler_arg=lr_scheduler_arg,
        **optimizer_kwargs,
    )
    # Parameter counts are read after the optimizer has been created.
    if hasattr(model, "get_nb_trainable_parameters"):
        num_trainable_params, num_params = model.get_nb_trainable_parameters()
    else:
        # No PEFT-style counter available: treat every parameter as trainable.
        num_params = model.num_parameters()
        num_trainable_params = num_params
    print_verbose(
        f"trainable params: {num_trainable_params:,d} || all params: {num_params:,d} || "
        f"trainable: {100 * num_trainable_params / num_params:.4f}%"
    )

    status = TrainStatus.FAILED
    tic_train = time.perf_counter()
    eval_time = 0.0
    error_msg = ""

    rows_wiki = get_wiki_small()
    model.eval()
    # Reference loss before any update; compared against the post-training loss to estimate forgetting.
    wiki_loss_before = calculate_mean_per_token_loss(
        model=model, tokenizer=tokenizer, rows=rows_wiki, batch_size=batch_size, max_length=768
    )
    model.train()

    ds_train, ds_valid, ds_test = get_train_valid_test_datasets(
        tokenizer=tokenizer, query_template=query_template, print_fn=print_verbose
    )
    # The "response" column is dropped from training batches via `delete_cols`.
    iterator_train = BucketIterator(
        ds_train,
        batch_size=batch_size,
        bucket_factor=BUCKET_FACTOR,
        delete_cols=["response"],
    )
    try:
        pbar = tqdm(range(1, max_steps + 1))
        for step, batch in zip(pbar, iterator_train):
            tic = time.perf_counter()

            # Token counts are taken before padding.
            # NOTE(review): one extra token per sample is added to the total, presumably for the EOS token -- confirm.
            tokens_per_sample = [len(i) for i in batch["input_ids"]]
            total_tokens.append(sum(tokens_per_sample) + len(tokens_per_sample))
            batch = tokenizer.pad(batch, return_tensors="pt").to(model.device)
            actual_batch_size = len(batch["input_ids"])
            total_samples += actual_batch_size

            # Labels are a copy of the inputs with everything past position `num_tokens` set to -100.
            # NOTE(review): position `num_tokens` itself is kept, presumably so that the first pad/EOS token is
            # still learned -- confirm against the tokenizer setup.
            labels = batch["input_ids"].clone()
            for i, num_tokens in enumerate(tokens_per_sample):
                labels[i, num_tokens + 1 :] = -100
            batch["labels"] = labels
            num_items_in_batch = batch["attention_mask"].sum().item()

            # Forward/backward pass and parameter update.
            optimizer.zero_grad()
            with autocast_ctx():
                outputs = model(**batch, num_items_in_batch=num_items_in_batch)
                loss = outputs.loss
            grad_scaler.scale(loss).backward()
            if grad_norm_clip:
                # Gradients must be unscaled before their norm is clipped.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
            grad_scaler.step(optimizer)
            grad_scaler.update()
            lr_scheduler.step()

            if is_adalora:
                model.base_model.update_and_allocate(step)

            losses.append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            accelerator_memory_allocated_log.append(
                torch_accelerator_module.memory_allocated() - accelerator_memory_init
            )
            accelerator_memory_reserved_log.append(
                torch_accelerator_module.memory_reserved() - accelerator_memory_init
            )
            toc = time.perf_counter()
            durations.append(toc - tic)

            # Periodic validation and logging.
            if step % eval_steps == 0:
                tic_eval = time.perf_counter()
                # At this point at least `eval_steps` entries exist, so dividing by `eval_steps` is exact.
                loss_avg = sum(losses[-eval_steps:]) / eval_steps
                memory_allocated_avg = sum(accelerator_memory_allocated_log[-eval_steps:]) / eval_steps
                memory_reserved_avg = sum(accelerator_memory_reserved_log[-eval_steps:]) / eval_steps
                token_sum = sum(total_tokens[-eval_steps:])
                dur_train = sum(durations[-eval_steps:])
                tokens_per_sec = token_sum / dur_train

                model.eval()
                predictions, responses = evaluate(
                    model=model,
                    tokenizer=tokenizer,
                    ds=ds_valid,
                    batch_size=batch_size_eval,
                    generate_kwargs={**generation_kwargs},
                )
                model.train()

                # Show one random prediction so the output can be sanity-checked by eye.
                example = random.choice(predictions)
                example = textwrap.shorten(example, width=750)
                example = textwrap.indent(example, " ")
                print_verbose(f"\nExample prediction:\n{example}\n")
                accuracy = get_accuracy(predictions=predictions, responses=responses)
                num_tokens_generated = sum(sum(mask) for mask in tokenizer(predictions)["attention_mask"])

                toc_eval = time.perf_counter()
                dur_eval = toc_eval - tic_eval
                eval_time += toc_eval - tic_eval
                elapsed = time.perf_counter() - tic_train

                metrics.append(
                    {
                        "step": step,
                        "valid accuracy": accuracy,
                        "train loss": loss_avg,
                        "train samples": total_samples,
                        "train time": dur_train,
                        "eval time": dur_eval,
                        "tokens / sec": tokens_per_sec,
                        "mem allocated avg": memory_allocated_avg,
                        "mem reserved avg": memory_reserved_avg,
                        "elapsed time": elapsed,
                    }
                )

                log_dict = {
                    "step": f"{step:5d}",
                    "samples": f"{total_samples:7d}",
                    "lr": f"{lr_scheduler.get_last_lr()[0]:.2e}",
                    "loss avg": f"{loss_avg:.4f}",
                    "valid acc": f"{accuracy:.3f}",
                    "gen valid tokens": num_tokens_generated,
                    "train time": f"{dur_train:.1f}s",
                    "eval time": f"{dur_eval:.1f}s",
                    "train tokens / sec": f"{tokens_per_sec:.0f}",
                    "mem allocated": f"{memory_allocated_avg:.0f}",
                    "mem reserved": f"{memory_reserved_avg:.0f}",
                    "elapsed time": f"{elapsed // 60:.0f}min {elapsed % 60:.0f}s",
                }
                print_verbose(json.dumps(log_dict))

            if step % ACCELERATOR_EMPTY_CACHE_SCHEDULE == 0:
                torch_accelerator_module.empty_cache()

        print_verbose(f"Training finished after {max_steps} steps, evaluation on test set follows.")
        # Final evaluation on the test set.
        model.eval()
        predictions, responses = evaluate(
            model=model,
            tokenizer=tokenizer,
            ds=ds_test,
            batch_size=batch_size_eval,
            generate_kwargs={**generation_kwargs, "pad_token_id": tokenizer.eos_token_id},
            use_tqdm=len(ds_test) > 100,
        )
        accuracy = get_accuracy(predictions=predictions, responses=responses)
        # Same Wikipedia sample as before training; the loss difference is the forgetting estimate.
        wiki_loss_after = calculate_mean_per_token_loss(
            model=model, tokenizer=tokenizer, rows=rows_wiki, batch_size=batch_size, max_length=768
        )
        forgetting = wiki_loss_after - wiki_loss_before
        # Average over the losses actually present in the window: fewer than `eval_steps` steps may have run, and
        # dividing by `eval_steps` would then under-report the loss.
        recent_losses = losses[-eval_steps:]
        metrics.append(
            {
                "step": step,
                "test accuracy": accuracy,
                "train loss": sum(recent_losses) / max(len(recent_losses), 1),
                "train samples": total_samples,
                "train total tokens": sum(total_tokens),
                "forgetting": forgetting,
            }
        )
        print_verbose(f"Test accuracy: {accuracy:.3f}")

    except KeyboardInterrupt:
        print_verbose("canceled training")
        status = TrainStatus.CANCELED
        error_msg = "manually canceled"
    except torch.OutOfMemoryError as exc:
        # Not re-raised, so the logs collected up to this point are still returned.
        print_verbose("out of memory error encountered")
        status = TrainStatus.CANCELED
        error_msg = str(exc)
    except Exception as exc:
        # Deliberately broad: any failure is reported as CANCELED and the partial logs are returned.
        print_verbose(f"encountered an error: {exc}")
        status = TrainStatus.CANCELED
        error_msg = str(exc)

    toc_train = time.perf_counter()
    # Time spent in periodic validation is excluded from the reported training time.
    train_time = toc_train - tic_train - eval_time

    if status != TrainStatus.CANCELED:
        status = TrainStatus.SUCCESS
    train_result = TrainResult(
        status=status,
        train_time=train_time,
        accelerator_memory_reserved_log=accelerator_memory_reserved_log,
        losses=losses,
        metrics=metrics,
        error_msg=error_msg,
        num_trainable_params=num_trainable_params,
        num_total_params=num_params,
    )
    return train_result
| |
|
| |
|
def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
    """Run one experiment end to end: load its configs, build the model, train it, and log the results.

    Args:
        path_experiment: Directory holding the training parameters file and, optionally, a PEFT config. Without a
            PEFT config the run is a full fine-tuning.
        experiment_name: Name under which the results are logged.
        clean: Forwarded to ``get_file_size``.

    Exits the process with status 1, without logging results, if training reports ``TrainStatus.FAILED``.
    """
    tic_total = time.perf_counter()
    start_date = dt.datetime.now(tz=dt.timezone.utc).replace(microsecond=0).isoformat()

    peft_branch = get_peft_branch()
    run_category = "MAIN" if peft_branch == "main" else "TEST"
    print_verbose(
        f"===== This experiment is categorized as a {run_category} run because the PEFT branch is "
        f"'{peft_branch}' ======"
    )

    # Load the configs; a missing PEFT config means full fine-tuning.
    peft_config: Optional[PeftConfig] = None
    peft_config_file = os.path.join(path_experiment, CONFIG_NAME)
    if not os.path.exists(peft_config_file):
        print_verbose(f"Could not find PEFT config at {path_experiment}, performing FULL FINETUNING")
    else:
        peft_config = PeftConfig.from_pretrained(path_experiment)
    train_config = get_train_config(os.path.join(path_experiment, FILE_NAME_TRAIN_PARAMS))
    set_seed(train_config.seed)

    # Initialize the accelerator first so its baseline memory reading is taken before the model exists.
    accelerator_memory_init = init_accelerator()
    tokenizer = get_tokenizer(model_id=train_config.model_id, max_seq_length=train_config.max_seq_length)

    model_info = get_base_model_info(train_config.model_id)
    datasets_info = {
        "metamath": get_dataset_info("meta-math/MetaMathQA"),
        "gsm8k": get_dataset_info("openai/gsm8k"),
    }
    model = get_model(
        model_id=train_config.model_id,
        dtype=train_config.dtype,
        compile=train_config.compile,
        attn_implementation=train_config.attn_implementation,
        peft_config=peft_config,
        autocast_adapter_dtype=train_config.autocast_adapter_dtype,
    )
    print_verbose(model)

    # Train the model.
    train_kwargs = {
        "model": model,
        "max_steps": train_config.max_steps,
        "batch_size": train_config.batch_size,
        "batch_size_eval": train_config.batch_size_eval,
        "tokenizer": tokenizer,
        "accelerator_memory_init": accelerator_memory_init,
        "eval_steps": train_config.eval_steps,
        "generation_kwargs": train_config.generation_kwargs,
        "grad_norm_clip": train_config.grad_norm_clip,
        "optimizer_type": train_config.optimizer_type,
        "optimizer_kwargs": train_config.optimizer_kwargs,
        "query_template": train_config.query_template,
        "lr_scheduler_arg": train_config.lr_scheduler,
        "use_amp": train_config.use_amp,
        "is_adalora": isinstance(peft_config, AdaLoraConfig),
    }
    train_result = train(**train_kwargs)

    if train_result.status == TrainStatus.FAILED:
        print_verbose("Training failed, not logging results")
        sys.exit(1)

    file_size = get_file_size(model, peft_config=peft_config, clean=clean, print_fn=print_verbose)

    time_total = time.perf_counter() - tic_total
    # Persist the results.
    log_results(
        experiment_name=experiment_name,
        train_result=train_result,
        accelerator_memory_init=accelerator_memory_init,
        time_total=time_total,
        file_size=file_size,
        model_info=model_info,
        datasets_info=datasets_info,
        start_date=start_date,
        train_config=train_config,
        peft_config=peft_config,
        print_fn=print_verbose,
    )
| |
|
| |
|
if __name__ == "__main__":
    # Command line: experiment directory (required), plus optional verbosity and clean-up switches.
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("path_experiment", type=str, help="Path to the experiment directory")
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Delete training artifacts after run finishes (logs are still saved)",
    )
    args = parser.parse_args()

    experiment_name = validate_experiment_path(args.path_experiment)

    # Captured once so that print_verbose does not depend on the parsed namespace.
    verbose_enabled = args.verbose

    def print_verbose(*args, **kwargs) -> None:
        """Print to stderr when verbose output was requested on the command line; otherwise do nothing."""
        if not verbose_enabled:
            return
        # Output always goes to stderr, even if the caller passes its own `file`.
        print(*args, **{**kwargs, "file": sys.stderr})

    main(
        path_experiment=args.path_experiment,
        experiment_name=experiment_name,
        clean=args.clean,
    )
| |
|