| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from mlxtend.evaluate import confusion_matrix |
| from mlxtend.plotting import plot_confusion_matrix |
| import model_builder |
|
|
def plot_confusion_Matrix(model_path, dataloader, class_names, device, figsize=(12, 12),
                          input_shape=3, hidden_units=15):
    """
    Generate and plot a confusion matrix with mlxtend from a saved PyTorch model and a DataLoader.

    A ``model_builder.TrashClassificationCNNModel`` is constructed, its weights
    are loaded from ``model_path``, and every batch of ``dataloader`` is run
    through it to collect predictions.

    Args:
        model_path: Path to the saved model state dict (e.g. a ".pth" or ".pt" file).
        dataloader: DataLoader instance yielding ``(inputs, labels)`` batches.
        class_names (list): List of class names; its length sets the model's output size.
        device: Target device to compute on (e.g., "cuda" or "cpu").
        figsize (tuple): Figure size.
        input_shape (int): ``input_shape`` passed to the model constructor.
        hidden_units (int): ``hidden_units`` passed to the model constructor.
            NOTE(review): presumably both must match the values used when the
            checkpoint was trained - confirm against the training script.

    Returns:
        None
    """
    model = model_builder.TrashClassificationCNNModel(
        input_shape=input_shape,
        hidden_units=hidden_units,
        output_shape=len(class_names),
    )

    # map_location remaps the stored tensors onto the requested device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))

    # The batches are moved to `device` below; the weights must live on the
    # same device or the forward pass fails with a device-mismatch error.
    model.to(device)
    model.eval()

    y_true = []
    y_pred = []

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for X, y in dataloader:
            y_logit = model(X.to(device))
            y_true.extend(y.cpu().numpy())
            # The predicted class is the index of the largest logit per sample.
            y_pred.extend(torch.argmax(y_logit, dim=1).cpu().numpy())

    confmat = confusion_matrix(y_target=y_true, y_predicted=y_pred, binary=False)

    _, ax = plot_confusion_matrix(conf_mat=confmat,
                                  class_names=class_names,
                                  figsize=figsize)

    # Label the axes returned by mlxtend directly instead of relying on
    # pyplot's notion of the "current" axes.
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    plt.show()
|
|
def plot_metrics(metrics):
    """
    Draw loss and accuracy curves for training and testing side by side.

    Args:
        metrics (dict): Per-epoch value sequences stored under the keys
            'train_loss', 'test_loss', 'train_acc' and 'test_acc'.

    Returns:
        None
    """
    # The x axis is the 1-based epoch index, sized from the training-loss history.
    epochs = range(1, len(metrics['train_loss']) + 1)

    # One row per subplot:
    # (position, train key, train label, test key, test label, title, y label)
    panels = (
        (1, 'train_loss', 'Training loss', 'test_loss', 'Testing loss',
         'Training and Testing Loss', 'Loss'),
        (2, 'train_acc', 'Training accuracy', 'test_acc', 'Testing accuracy',
         'Training and Testing Accuracy', 'Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for position, train_key, train_label, test_key, test_label, title, y_label in panels:
        plt.subplot(1, 2, position)
        # Training curve in blue, testing curve in red.
        plt.plot(epochs, metrics[train_key], 'b', label=train_label)
        plt.plot(epochs, metrics[test_key], 'r', label=test_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()
|
|