| import torch |
| from torchvision import transforms |
| from PIL import Image |
| import model_builder |
| import matplotlib.pyplot as plt |
| import argparse |
|
|
def predict_image(image_path: str,
                  model_path: str,
                  class_names: list) -> tuple:
    """
    Predict the class of an image with a saved model and plot the result.

    The image is converted to RGB, resized to 112x112 and passed through a
    ``model_builder.TrashClassificationCNNModel`` whose weights are loaded
    from ``model_path``. The resized image is then displayed with matplotlib,
    titled with the predicted class and its softmax probability.

    Args:
        image_path (str): Path to the input image.
        model_path (str): Path to the saved PyTorch state dict.
        class_names (list): Class names indexed by model output position.
            Its length is used as the model's output size, so it must match
            the model that produced the state dict.

    Returns:
        predicted_class (str): Predicted class label.
        probability_percentage (float): Probability percentage of the predicted class.
    """
    # Open inside a context manager so the underlying file handle is released;
    # convert() returns a new in-memory image that stays valid after closing.
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor()
    ])

    # Add a leading batch dimension of size 1 for the model.
    image_tensor = transform(image).unsqueeze(0)

    model = model_builder.TrashClassificationCNNModel(input_shape=3,
                                                      hidden_units=15,
                                                      output_shape=len(class_names))

    # NOTE(security): torch.load unpickles the file, so only load state dicts
    # from trusted sources.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
    # inference below runs on the CPU anyway.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()

    with torch.inference_mode():
        # Assumes the model returns logits of shape (1, len(class_names)).
        output = model(image_tensor)
        # Index the batch row explicitly instead of squeeze(): squeeze() would
        # collapse a single-class output to a scalar and break the lookup.
        probabilities = torch.softmax(output, dim=1)[0]
        predicted_index = torch.argmax(output, dim=1).item()
        probability_percentage = probabilities[predicted_index].item() * 100

    predicted_class = class_names[predicted_index]

    # Drop only the batch dimension, then move channels last for imshow.
    plt.imshow(image_tensor.squeeze(0).permute(1, 2, 0))
    plt.title(f'Predicted Class: {predicted_class} | Probability: {probability_percentage:.2f}%',
              fontdict={'family': 'serif',
                        'color': 'black',
                        'weight': 'normal',
                        'size': 16})
    plt.axis(False)
    plt.show()

    return predicted_class, probability_percentage
|
|
| |
# Class labels passed to predict_image, in model output order.
CLASS_NAMES = ['Cardboard',
               'Food Organics',
               'Glass',
               'Metal',
               'Miscellaneous Trash',
               'Paper',
               'Plastic',
               'Textile Trash',
               'Vegetation']


def main() -> None:
    """
    Command-line entry point: parse --image / --model_path and run a prediction.

    Exits with a non-zero status when a required argument is missing or when
    the prediction fails, so shell callers can detect the failure.
    """
    parser = argparse.ArgumentParser(description="For Prediction of an Image using a Loaded Model")
    parser.add_argument("--image", type=str, default=None, help="Path to the image to make predictions on.")
    parser.add_argument("--model_path", type=str, default=None, help="Path to the Trained Models State Dict.")
    args = parser.parse_args()

    if not args.image:
        print("Please Enter Image Path using --image")
        raise SystemExit(1)

    if not args.model_path:
        print("Please Enter Model Path using --model_path")
        raise SystemExit(1)

    try:
        predict_image(image_path=args.image,
                      model_path=args.model_path,
                      class_names=CLASS_NAMES)
    except Exception as exception:
        # Top-level boundary: report the failure and exit non-zero instead of
        # finishing with a success status.
        print("INVALID MODEL_PATH OR INVALID IMAGE PATH")
        print(f"\n{exception}")
        raise SystemExit(1) from exception


# Only run the CLI when executed as a script, so importing this module for
# predict_image does not parse sys.argv or trigger a prediction.
if __name__ == "__main__":
    main()