| | |
| | |
| | |
| |
|
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
from trl import SFTConfig, SFTTrainer
| |
|
# Fetch the German container-status corpus and hold out 15% for
# evaluation; the shuffle is seeded so the split is reproducible.
dataset = load_dataset("mindchain/container-status-de", split="train")
split = dataset.train_test_split(seed=42, test_size=0.15)
| |
|
def fmt(ex):
    """Convert one raw example into TRL's prompt/completion format.

    The original mapping put only the status text into ``"text"`` and kept
    the label as a stray ``"label"`` column — SFTTrainer trains on the text
    field and ignores that column, so the model would never see the target
    it is supposed to learn.  Emitting ``"prompt"``/``"completion"`` makes
    the label the supervised target, which is also the natural dataset
    shape for an encoder-decoder (seq2seq) model.

    Args:
        ex: A dataset row with ``"text"`` (status description) and
            ``"label"`` (target class string) keys.

    Returns:
        A dict with ``"prompt"`` (the formatted input) and ``"completion"``
        (the label the model should generate).
    """
    return {"prompt": f"Status: {ex['text']}", "completion": ex["label"]}
| |
|
def _prep(ds):
    # Project the split through fmt and drop the raw columns so only the
    # formatted fields reach the trainer.
    return ds.map(fmt, remove_columns=ds.column_names)


train_ds = _prep(split["train"])
eval_ds = _prep(split["test"])
| |
|
MODEL_ID = "google/t5gemma-2-270m"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# T5Gemma checkpoints are not plain-T5 architecture: loading them through
# T5ForConditionalGeneration fails on the config-class mismatch.  The Auto
# class resolves the correct seq2seq architecture from the checkpoint's
# own config, and still works for genuine T5 models.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
| |
|
# Training hyperparameters; checkpoints and the final model are pushed to
# the Hub under hub_model_id.
config = SFTConfig(
    output_dir="out",
    push_to_hub=True,
    hub_model_id="mindchain/t5gemma-270m-container-status",
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective train batch size: 2 * 4 = 8
    learning_rate=3e-4,
    logging_steps=5,
    max_length=256,
    # An eval_dataset is handed to the trainer below, but without an
    # evaluation strategy it would never actually be evaluated.
    eval_strategy="epoch",
    report_to="trackio",
)
| |
|
# Recent TRL releases (the same ones that accept SFTConfig(max_length=...))
# removed SFTTrainer's `tokenizer` keyword; the tokenizer is passed as
# `processing_class` instead.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=config,
)
trainer.train()
trainer.push_to_hub()  # final upload to hub_model_id
print('DONE')
| |
|