| | --- |
| | license: mit |
| | --- |
| | ## Usage |
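| |
| The snippet below loads the pretrained `BrachioLab/supernova-classification` checkpoint and reports its accuracy on a test set. `test_dataloader` is not defined in the snippet: provide your own DataLoader over the test split whose batches are dicts of tensors that include a `labels` key. |
| |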
|
| | ```python |
| | import torch |
| | from informer_models import InformerConfig, InformerForSequenceClassification |
| | |
| | model = InformerForSequenceClassification.from_pretrained("BrachioLab/supernova-classification") |
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| model.to(device) |
| | model.eval() |
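| y_true = [] |
| y_pred = [] |
| # test_dataloader is not defined here: supply your own DataLoader over the test split. |
| # Each batch must be a dict of tensors with a "labels" key; an "objid" key, if present, is dropped. |
| for i, batch in enumerate(test_dataloader): |
|     print(f"processing batch {i}") |
|     batch = {k: v.to(device) for k, v in batch.items() if k != "objid"} |
|     with torch.no_grad(): |
|         outputs = model(**batch) |
|     y_true.extend(batch['labels'].cpu().numpy()) |
|     # reshape(-1) instead of squeeze(): squeeze() also removes the batch axis when a batch |
|     # holds a single sample, yielding a 0-d array that list.extend cannot iterate. |
|     y_pred.extend(torch.argmax(outputs.logits, dim=2).reshape(-1).cpu().numpy()) |
| correct = sum(int(t == p) for t, p in zip(y_true, y_pred)) |
| print(f"accuracy: {correct / len(y_true)}") |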
| | ``` |