| | import sys |
| | from pathlib import Path |
| | sys.path.append(str(Path(__file__).resolve().parent.parent)) |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import json |
| | from models.model import Microformer |
| | from config import * |
| |
|
| | |
| | |
| | |
# Load the token vocabulary produced by the data-prep step.
# Fix: read the JSON explicitly as UTF-8 so the script does not depend on the
# platform default encoding (avoids mojibake on e.g. Windows cp1252).
with open("data/vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)

stoi = vocab["stoi"]  # token string -> id
# JSON object keys are always strings; convert back to int ids for lookup.
itos = {int(k): v for k, v in vocab["itos"].items()}
VOCAB_SIZE = len(stoi)
| |
|
# Tokenized training stream (assumed 1-D LongTensor saved by the data-prep
# step — TODO confirm against the prep script).
data = torch.load("data/train.pt")

SEQ_LEN = MAX_SEQ_LEN
BATCH_SIZE = 32

# Drop the tail so the stream splits evenly into BATCH_SIZE parallel rows,
# then fold it into a (BATCH_SIZE, steps) matrix for batched slicing.
num_batches = len(data) // (SEQ_LEN * BATCH_SIZE)
trimmed_len = num_batches * SEQ_LEN * BATCH_SIZE
data = data[:trimmed_len].view(BATCH_SIZE, -1)
| |
|
def get_batch(start_idx):
    """Return (inputs, targets) slices from the global ``data`` matrix.

    Targets are the inputs shifted one position right (next-token
    prediction). Both are (BATCH_SIZE, SEQ_LEN) views; no copy is made.
    """
    stop = start_idx + SEQ_LEN
    inputs = data[:, start_idx:stop]
    targets = data[:, start_idx + 1:stop + 1]
    return inputs, targets
| |
|
| | |
| | |
| | |
# Prefer GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the transformer; hyperparameters come from the star-imported config.
model = Microformer(
    VOCAB_SIZE,
    EMBED_DIM,
    NUM_HEADS,
    FF_DIM,
    NUM_LAYERS,
    MAX_SEQ_LEN,
    long_term_adapter_dim=ADAPTER_DIM,
    session_adapter_dim=ADAPTER_DIM,
).to(device)
| |
|
| | |
| | |
| | |
# Pretraining phase: train the long-term adapters plus the output head.
model.freeze_except_adapters(session_only=False, include_output=True)

# Additionally freeze the session adapters — they are reserved for the
# online-update phase and must not move during pretraining.
for layer in model.layers:
    adapter = getattr(layer, 'session_adapter', None)
    if adapter is not None:
        for p in adapter.parameters():
            p.requires_grad = False

criterion = nn.CrossEntropyLoss()
# Optimize only the parameters left trainable by the freezing above.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = optim.Adam(trainable_params, lr=1e-3)
| |
|
| | |
| | |
| | |
# Pretraining loop: six full passes over the batched token stream.
for epoch in range(6):
    for start in range(0, data.shape[1] - SEQ_LEN, SEQ_LEN):
        inputs, targets = get_batch(start)
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        # Flatten (B, T, V) -> (B*T, V) for token-level cross entropy.
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        loss.backward()
        optimizer.step()

    # NOTE(review): this prints only the final batch's loss for the epoch,
    # not an epoch average — preserved as-is.
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

torch.save(model.state_dict(), "microformer.pt")
| |
|
| | |
| | |
| | |
def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
    """Run one self-supervised next-token update on a single text snippet.

    Encodes ``text``, appends ``<EOS>``, truncates to ``max_len`` input
    positions, and takes one optimizer step against the one-step-shifted
    targets. The model is left in eval mode afterwards.

    Args:
        model: callable module mapping (1, T) token ids to (1, T, V) logits.
        tokenizer: exposes ``.encode(text).ids`` and ``.token_to_id(token)``.
        text: raw text to learn from.
        optimizer: optimizer over the model's trainable parameters.
        loss_fn: cross-entropy-style loss over (N, V) logits / (N,) targets.
        device: torch device (string or device) for the input tensors.
        max_len: fixed number of input positions fed to the model.

    Returns:
        The scalar loss as a float, or None when ``text`` yields fewer than
        two tokens (no (input, target) pair exists).
    """
    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
    if len(ids) < 2:
        return None  # need at least one input token and one target token

    ids = ids[:max_len + 1]
    input_ids = ids[:-1]
    target_ids = ids[1:]
    seq_len = len(input_ids)  # number of real (unpadded) positions

    # Right-pad to a fixed width so the forward pass always sees max_len
    # positions (assumes causal attention, so right-padding cannot leak into
    # the real positions — TODO confirm against the Microformer attention).
    pad_id = tokenizer.token_to_id("<PAD>")
    input_ids += [pad_id] * (max_len - seq_len)
    target_ids += [pad_id] * (max_len - seq_len)
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)

    model.train()
    logits = model(input_tensor)
    # Fix: compute the loss only over the real tokens. The original averaged
    # the loss over every padded position as well, which biased each update
    # toward predicting <PAD>, especially for short snippets.
    active_logits = logits[:, :seq_len, :].reshape(-1, logits.size(-1))
    active_targets = target_tensor[:, :seq_len].reshape(-1)
    loss = loss_fn(active_logits, active_targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()
| |
|
| | |
| | |
| | |
def reset_session_adapters(model):
    """Reset every session adapter's parameters to zero in place.

    Zeroing all adapter weights is presumably intended to make each adapter
    a no-op contribution (assumes the adapter output is added residually —
    TODO confirm against the Microformer implementation), giving a clean
    slate for the next online session. Layers without a session adapter
    are skipped.
    """
    for layer in model.layers:
        adapter = getattr(layer, 'session_adapter', None)
        if adapter is None:
            continue
        with torch.no_grad():
            for param in adapter.parameters():
                # Fix: the original guarded on `param.data is not None`,
                # which is always true for an nn.Parameter — dead check
                # removed; zeroing done in place without autograd tracking.
                param.zero_()
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|