Upload 21 files
Browse files- .gitattributes +2 -0
- Dockerfile +37 -0
- Notebooks/PyTorch_CNN_Image_Classification_Hugging_Face.ipynb +0 -0
- Notebooks/PyTorch_Image_Classification.ipynb +3 -0
- Notebooks/PyTorch_Xception_CNN_Image_Classification.ipynb +0 -0
- Results/PyTorch_Unified_Model_Comparison.mp4 +3 -0
- app.py +395 -0
- docker-compose.yml +25 -0
- docker-quickstart.bat +16 -0
- docker-quickstart.sh +16 -0
- model_handlers/__init__.py +1 -0
- model_handlers/basic_cnn_handler.py +191 -0
- model_handlers/hugging_face_handler.py +71 -0
- model_handlers/xception_handler.py +110 -0
- models/basic_cnn/cnn_model_statedict_20260226_034332.pth +3 -0
- models/basic_cnn/deployment_config.json +11 -0
- models/basic_cnn/model_metadata_20260226_034332.json +31 -0
- models/hugging_face/config.json +44 -0
- models/hugging_face/model.safetensors +3 -0
- models/hugging_face/preprocessor_config.json +23 -0
- models/xception/best_model_finetuned_full.pt +3 -0
- requirements.txt +33 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Notebooks/PyTorch_Image_Classification.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Results/PyTorch_Unified_Model_Comparison.mp4 filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use official Python runtime as a parent image
FROM python:3.10-slim

# Set working directory in container
WORKDIR /app

# Install system dependencies required for PyTorch and image processing
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    git \
    libssl-dev \
    libffi-dev \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port for Gradio
EXPOSE 7860

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV GRADIO_SERVER_NAME=0.0.0.0
ENV GRADIO_SERVER_PORT=7860

# Health check: actually probe the Gradio HTTP endpoint.
# (The previous check ran `python -c "import sys; sys.exit(0)"`, which
# always succeeds and therefore never detects an unhealthy app.)
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:7860/', timeout=5)" || exit 1

# Run the application
CMD ["python", "app.py"]
|
Notebooks/PyTorch_CNN_Image_Classification_Hugging_Face.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Notebooks/PyTorch_Image_Classification.ipynb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf12bb81b637521294883fad25cac2866334c1818b1d2976e53c2b085cdc5641
|
| 3 |
+
size 28301580
|
Notebooks/PyTorch_Xception_CNN_Image_Classification.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Results/PyTorch_Unified_Model_Comparison.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:35c3480db4bb26d609bd5cddc64ff007dab3957fe2203a95d592558a9e42d28d
|
| 3 |
+
size 18008325
|
app.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 4 |
+
from typing import Tuple, Dict
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import torch
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
|
| 12 |
+
# Import model handlers
|
| 13 |
+
from model_handlers.basic_cnn_handler import BasicCNNModel
|
| 14 |
+
from model_handlers.hugging_face_handler import HuggingFaceModel
|
| 15 |
+
from model_handlers.xception_handler import XceptionModel
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Global Configuration

# Absolute paths anchored at this file so the app works from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")

# One sub-directory per model artifact set.
MODEL_1_DIR = os.path.join(MODELS_DIR, "basic_cnn")
MODEL_2_DIR = os.path.join(MODELS_DIR, "hugging_face")
MODEL_3_DIR = os.path.join(MODELS_DIR, "xception")

# Model instances (loaded at startup; None means the model failed to load
# and the corresponding predict_with_model_* returns an error result)
basic_cnn_model = None
hugging_face_model = None
xception_model = None

# Dataset for random image selection (populated lazily on first use too)
dataset = None
DATASET_NAME = "AIOmarRehan/Vehicles"

# Display metadata for the three models shown in the UI accordion.
MODELS_INFO = {
    "Model 1: Basic CNN": {
        "description": "Custom CNN architecture with 4 Conv blocks and BatchNorm",
        "path": MODEL_1_DIR,
        "handler_class": BasicCNNModel
    },
    "Model 2: Hugging Face Transformers (DeiT-Tiny | Meta)": {
        "description": "Pre-trained transformer-based model from Hugging Face (DeiT-Tiny | Meta)",
        "path": MODEL_2_DIR,
        "handler_class": HuggingFaceModel
    },
    "Model 3: Xception CNN": {
        "description": "Fine-tuned Xception architecture using timm library",
        "path": MODEL_3_DIR,
        "handler_class": XceptionModel
    }
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Model Loading
|
| 56 |
+
|
| 57 |
+
def load_models():
    """Load all three models at startup.

    Each model is loaded independently (best-effort): a failure leaves that
    model's module-level global set to None so the rest of the app keeps
    working, matching the None-checks in the predict_with_model_* wrappers.
    """
    global basic_cnn_model, hugging_face_model, xception_model

    def _try_load(step, name, handler_class, model_dir):
        # Best-effort loader: log progress, return the instance or None.
        try:
            print(f"\n[{step}/3] Loading {name}...")
            instance = handler_class(model_dir)
            print(f"{name} loaded successfully")
            return instance
        except Exception as e:
            print(f"Failed to load {name}: {e}")
            return None

    print("\n" + "="*60)
    print("Loading Models...")
    print("="*60)

    basic_cnn_model = _try_load(1, "Basic CNN Model", BasicCNNModel, MODEL_1_DIR)
    hugging_face_model = _try_load(2, "Hugging Face (DeiT-Tiny | Meta) Model", HuggingFaceModel, MODEL_2_DIR)
    xception_model = _try_load(3, "Xception Model", XceptionModel, MODEL_3_DIR)

    print("\n" + "="*60)
    print("Model Loading Complete!")
    print("="*60 + "\n")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_dataset_split():
    """Load the Hugging Face dataset used for the "random image" feature.

    Populates the module-level `dataset` global; on any failure the global
    is set to None and the random-image feature is disabled.
    """
    global dataset

    try:
        print("\nLoading dataset from Hugging Face...")
        # Load the train split of the dataset.
        # (Previous comment claimed "test split" but the code loads "train".)
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # dataset repository — confirm the repo is trusted before deploying.
        dataset = load_dataset(DATASET_NAME, split="train", trust_remote_code=True)
        print(f"Dataset loaded successfully: {len(dataset)} images available")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        print("Random image feature will be disabled")
        dataset = None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_random_image():
    """Return a random PIL image from the dataset, or None when unavailable.

    Lazily (re)loads the dataset if it has not been loaded yet.
    """
    if dataset is None:
        print("Dataset not loaded, attempting to load...")
        load_dataset_split()

    if dataset is None:
        return None

    try:
        # Pick one sample at random.
        idx = random.randint(0, len(dataset) - 1)
        sample = dataset[idx]

        # Prefer the conventional keys, then fall back to scanning values.
        for key in ('image', 'img'):
            if key in sample:
                picture = sample[key]
                break
        else:
            for value in sample.values():
                if isinstance(value, Image.Image):
                    picture = value
                    break
            else:
                print(f"Could not find image in sample keys: {sample.keys()}")
                return None

        print(f"Loaded random image from index {idx}")
        return picture
    except Exception as e:
        print(f"Error loading random image: {e}")
        return None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Prediction Functions
|
| 146 |
+
|
| 147 |
+
def _predict_with(model, model_name: str, error_label: str, image: Image.Image) -> Tuple[str, float, Dict]:
    """Shared prediction wrapper used by the three public predict functions.

    Returns (error_label, 0.0, {}) when the model failed to load, and
    ("Error", 0.0, {}) when prediction raises; otherwise delegates to
    model.predict(image) and returns its (label, confidence, prob_dict).
    """
    if model is None:
        return error_label, 0.0, {}
    try:
        return model.predict(image)
    except Exception as e:
        print(f"Error in {model_name} prediction: {e}")
        return "Error", 0.0, {}


def predict_with_model_1(image: Image.Image) -> Tuple[str, float, Dict]:
    """Predict with Basic CNN Model"""
    return _predict_with(basic_cnn_model, "Model 1", "Model 1: Error", image)


def predict_with_model_2(image: Image.Image) -> Tuple[str, float, Dict]:
    """Predict with Hugging Face (DeiT-Tiny | Meta) Model"""
    return _predict_with(hugging_face_model, "Model 2", "Model 2: Error", image)


def predict_with_model_3(image: Image.Image) -> Tuple[str, float, Dict]:
    """Predict with Xception Model"""
    return _predict_with(xception_model, "Model 3", "Model 3: Error", image)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def predict_all_models(image: Image.Image):
    """Run all three models on `image` concurrently and build UI outputs.

    Returns an 8-tuple matching the Gradio output wiring:
    (result_1, result_2, result_3, comparison_text,
     result_1_probs, result_2_probs, result_3_probs, consensus_html).
    """
    if image is None:
        # No image yet: return placeholder values for every output widget.
        empty_result = {"Model": "N/A", "Prediction": "No image", "Confidence": 0.0}
        empty_probs = {}
        empty_consensus = "<p>Please upload an image to see results</p>"
        return empty_result, empty_result, empty_result, "Please upload an image", empty_probs, empty_probs, empty_probs, empty_consensus

    print("\n" + "="*60)
    print("Running Predictions with All Models...")
    print("="*60)

    # Run predictions in parallel (one worker per model)
    with ThreadPoolExecutor(max_workers=3) as executor:
        future_1 = executor.submit(predict_with_model_1, image)
        future_2 = executor.submit(predict_with_model_2, image)
        future_3 = executor.submit(predict_with_model_3, image)

        # Wait for all predictions to complete
        result_1_label, result_1_conf, result_1_probs = future_1.result()
        result_2_label, result_2_conf, result_2_probs = future_2.result()
        result_3_label, result_3_conf, result_3_probs = future_3.result()

    # Format results for display (confidence as a percentage string)
    result_1 = {
        "Model": "Basic CNN",
        "Prediction": result_1_label,
        "Confidence": f"{result_1_conf * 100:.2f}%"
    }

    result_2 = {
        "Model": "Hugging Face (DeiT-Tiny | Meta)",
        "Prediction": result_2_label,
        "Confidence": f"{result_2_conf * 100:.2f}%"
    }

    result_3 = {
        "Model": "Xception",
        "Prediction": result_3_label,
        "Confidence": f"{result_3_conf * 100:.2f}%"
    }

    # Check if all models agree (labels compared by string equality)
    all_agree = result_1_label == result_2_label == result_3_label

    # Create comparison text with HTML styling
    if all_agree:
        consensus_html = f"""
        <div style="background-color: #d4edda; border: 2px solid #28a745; border-radius: 8px; padding: 20px; text-align: center;">
            <h3 style="color: #155724; margin: 0; font-size: 24px;">All Models Agree!</h3>
            <p style="color: #155724; margin: 10px 0 0 0; font-size: 18px; font-weight: bold;">{result_1_label}</p>
        </div>
        """
    else:
        consensus_html = f"""
        <div style="background-color: #f8d7da; border: 2px solid #dc3545; border-radius: 8px; padding: 20px; text-align: center;">
            <h3 style="color: #721c24; margin: 0; font-size: 24px;">Models Disagree</h3>
            <p style="color: #721c24; margin: 10px 0 0 0; font-size: 16px;">Check predictions below for details</p>
        </div>
        """

    comparison_text = f"""
## Comparison Results

**Model 1 (Basic CNN):** {result_1_label} ({result_1_conf * 100:.2f}%)

**Model 2 (Hugging Face (DeiT-Tiny | Meta)):** {result_2_label} ({result_2_conf * 100:.2f}%)

**Model 3 (Xception):** {result_3_label} ({result_3_conf * 100:.2f}%)
"""

    print(f"Prediction 1: {result_1_label} ({result_1_conf * 100:.2f}%)")
    print(f"Prediction 2: {result_2_label} ({result_2_conf * 100:.2f}%)")
    print(f"Prediction 3: {result_3_label} ({result_3_conf * 100:.2f}%)")
    print(f"Consensus: {'All agree!' if all_agree else 'Disagreement detected'}")
    print("="*60 + "\n")

    return result_1, result_2, result_3, comparison_text, result_1_probs, result_2_probs, result_3_probs, consensus_html
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Gradio Interface
|
| 263 |
+
|
| 264 |
+
def build_interface() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI.

    Wires the upload/predict/random-image controls to predict_all_models()
    and get_random_image(). Does not launch the server; the caller invokes
    .launch() on the returned Blocks object.
    """
    with gr.Blocks(
        title="PyTorch Unified Model Comparison",
        theme=gr.themes.Soft()
    ) as demo:

        # Header
        gr.Markdown("""
        # PyTorch Unified Model Comparison

        Upload an image and compare predictions from three different PyTorch models **simultaneously**.

        This tool helps you understand how different architectures (Basic CNN, Transformers, Xception)
        classify the same image and identify where they agree or disagree.
        """)

        # Model Information
        with gr.Accordion("Model Information", open=False):
            gr.Markdown(f"""
            ### Model 1: Basic CNN
            - **Description:** {MODELS_INFO['Model 1: Basic CNN']['description']}
            - **Architecture:** 4 Conv blocks + BatchNorm + Global Avg Pooling
            - **Input Size:** 224×224

            ### Model 2: Hugging Face Transformers (DeiT-Tiny | Meta)
            - **Description:** {MODELS_INFO['Model 2: Hugging Face Transformers (DeiT-Tiny | Meta)']['description']}
            - **Framework:** transformers library

            ### Model 3: Xception CNN
            - **Description:** {MODELS_INFO['Model 3: Xception CNN']['description']}
            - **Architecture:** Fine-tuned Xception with timm
            """)

        # Input Section
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    sources=["upload", "webcam"]
                )
                predict_btn = gr.Button("Predict with All Models", variant="primary", size="lg")
                random_img_btn = gr.Button("Load Random Image from Dataset", variant="secondary", size="lg")

        # Output Section
        gr.Markdown("## Results")

        with gr.Row():
            with gr.Column():
                result_1_box = gr.JSON(label="Model 1: Basic CNN")
            with gr.Column():
                result_2_box = gr.JSON(label="Model 2: Hugging Face (DeiT-Tiny)")
            with gr.Column():
                result_3_box = gr.JSON(label="Model 3: Xception")

        # Comparison Section
        comparison_output = gr.Markdown(label="Comparison Summary")

        # Consensus Indicator (HTML for colored styling)
        consensus_output = gr.HTML(value="<p></p>")

        # Class Probabilities Section
        gr.Markdown("## Class Probabilities")

        with gr.Row():
            with gr.Column():
                probs_1 = gr.Label(label="Model 1: Basic CNN | Probabilities")
            with gr.Column():
                probs_2 = gr.Label(label="Model 2: DeiT-Tiny | Meta | Probabilities")
            with gr.Column():
                probs_3 = gr.Label(label="Model 3: Xception | Probabilities")

        # Connect button click
        predict_btn.click(
            fn=predict_all_models,
            inputs=image_input,
            outputs=[result_1_box, result_2_box, result_3_box, comparison_output, probs_1, probs_2, probs_3, consensus_output]
        )

        # Also trigger on image upload
        image_input.change(
            fn=predict_all_models,
            inputs=image_input,
            outputs=[result_1_box, result_2_box, result_3_box, comparison_output, probs_1, probs_2, probs_3, consensus_output]
        )

        # Connect random image button
        random_img_btn.click(
            fn=get_random_image,
            inputs=None,
            outputs=image_input
        )

        # Footer
        gr.Markdown("""
        ---

        **Available Classes:** Auto Rickshaws | Bikes | Cars | Motorcycles | Planes | Ships | Trains

        **Dataset:** Random images are loaded from [AIOmarRehan/Vehicles](https://huggingface.co/datasets/AIOmarRehan/Vehicles) on Hugging Face

        This unified application allows real-time comparison of three different deep learning models
        to understand their individual strengths and weaknesses.
        """)

    return demo
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# Main Entry Point

if __name__ == "__main__":
    # Load all models at startup
    load_models()

    # Load dataset for random image selection
    load_dataset_split()

    # Build and launch Gradio interface
    demo = build_interface()

    server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))

    print(f"\nLaunching Gradio Interface on {server_name}:{server_port}")
    # Bug fix: reflect the configured port instead of hard-coding 7860.
    print(f"Open your browser and navigate to http://localhost:{server_port}\n")

    demo.launch(
        server_name=server_name,
        server_port=server_port,
        share=False,
        show_error=True
    )
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE(review): the top-level `version` key is ignored (with a warning) by
# Docker Compose V2; kept here for compatibility with legacy docker-compose v1.
version: '3.8'

services:
  unified-model-app:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: pytorch-unified-model-comparison
    ports:
      # Host 7860 -> container 7860 (Gradio default, see GRADIO_SERVER_PORT)
      - "7860:7860"
    environment:
      PYTHONUNBUFFERED: 1
      GRADIO_SERVER_NAME: 0.0.0.0
      GRADIO_SERVER_PORT: 7860
    volumes:
      # Optional: Mount models directory if you want to update models without rebuilding
      - ./models:/app/models
    restart: unless-stopped
    # Optional: Increase memory limit if needed
    # deploy:
    #   resources:
    #     limits:
    #       memory: 8G
    #     reservations:
    #       memory: 4G
|
docker-quickstart.bat
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
REM Quick Docker Compose startup script for Windows

echo Building Docker image...
docker-compose build
REM Abort instead of claiming success after a failed build.
if errorlevel 1 (
    echo Build failed. Aborting.
    exit /b 1
)

echo.
echo Starting container...
docker-compose up -d
if errorlevel 1 (
    echo Failed to start container.
    exit /b 1
)

echo.
echo Container is running!
echo Access the app at: http://localhost:7860
echo.
echo To stop the container, run: docker-compose down
echo To view logs, run: docker-compose logs -f
|
docker-quickstart.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Quick Docker Compose startup script for Linux/Mac

# Abort on the first failing command so we never print the success
# messages after a failed build or start.
set -e

echo "Building Docker image..."
docker-compose build

echo ""
echo "Starting container..."
docker-compose up -d

echo ""
echo "Container is running!"
echo "Access the app at: http://localhost:7860"
echo ""
echo "To stop the container, run: docker-compose down"
echo "To view logs, run: docker-compose logs -f"
|
model_handlers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Initialize model_handlers package
|
model_handlers/basic_cnn_handler.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Tuple, Dict
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BasicCNNModel:
|
| 11 |
+
|
| 12 |
+
class CNN(torch.nn.Module):
|
| 13 |
+
def __init__(self, num_classes, dropout_rate=0.4):
|
| 14 |
+
super(BasicCNNModel.CNN, self).__init__()
|
| 15 |
+
self.dropout_rate = dropout_rate
|
| 16 |
+
|
| 17 |
+
# Conv Block 1: 3 → 32 channels
|
| 18 |
+
self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False)
|
| 19 |
+
self.bn1 = torch.nn.BatchNorm2d(32)
|
| 20 |
+
self.relu1 = torch.nn.ReLU(inplace=True)
|
| 21 |
+
self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
|
| 22 |
+
|
| 23 |
+
# Conv Block 2: 32 → 64 channels
|
| 24 |
+
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
|
| 25 |
+
self.bn2 = torch.nn.BatchNorm2d(64)
|
| 26 |
+
self.relu2 = torch.nn.ReLU(inplace=True)
|
| 27 |
+
self.maxpool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
|
| 28 |
+
|
| 29 |
+
# Conv Block 3: 64 → 128 channels
|
| 30 |
+
self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
|
| 31 |
+
self.bn3 = torch.nn.BatchNorm2d(128)
|
| 32 |
+
self.relu3 = torch.nn.ReLU(inplace=True)
|
| 33 |
+
self.maxpool3 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
|
| 34 |
+
|
| 35 |
+
# Conv Block 4: 128 → 256 channels
|
| 36 |
+
self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False)
|
| 37 |
+
self.bn4 = torch.nn.BatchNorm2d(256)
|
| 38 |
+
self.relu4 = torch.nn.ReLU(inplace=True)
|
| 39 |
+
self.maxpool4 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
|
| 40 |
+
|
| 41 |
+
# Global Average Pooling (adaptive pooling to 1x1)
|
| 42 |
+
self.global_avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
| 43 |
+
|
| 44 |
+
# Classifier head with Dropout
|
| 45 |
+
self.dropout1 = torch.nn.Dropout(p=dropout_rate)
|
| 46 |
+
self.fc1 = torch.nn.Linear(256, 512)
|
| 47 |
+
self.fc1_relu = torch.nn.ReLU(inplace=True)
|
| 48 |
+
|
| 49 |
+
self.dropout2 = torch.nn.Dropout(p=dropout_rate)
|
| 50 |
+
self.fc2 = torch.nn.Linear(512, num_classes)
|
| 51 |
+
|
| 52 |
+
# Initialize weights
|
| 53 |
+
self._init_weights()
|
| 54 |
+
|
| 55 |
+
def _init_weights(self):
|
| 56 |
+
for m in self.modules():
|
| 57 |
+
if isinstance(m, torch.nn.Conv2d):
|
| 58 |
+
torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 59 |
+
elif isinstance(m, torch.nn.BatchNorm2d):
|
| 60 |
+
torch.nn.init.constant_(m.weight, 1)
|
| 61 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 62 |
+
elif isinstance(m, torch.nn.Linear):
|
| 63 |
+
torch.nn.init.normal_(m.weight, 0, 0.01)
|
| 64 |
+
if m.bias is not None:
|
| 65 |
+
torch.nn.init.constant_(m.bias, 0)
|
| 66 |
+
|
| 67 |
+
def forward(self, x):
|
| 68 |
+
# Block 1
|
| 69 |
+
x = self.conv1(x)
|
| 70 |
+
x = self.bn1(x)
|
| 71 |
+
x = self.relu1(x)
|
| 72 |
+
x = self.maxpool1(x)
|
| 73 |
+
|
| 74 |
+
# Block 2
|
| 75 |
+
x = self.conv2(x)
|
| 76 |
+
x = self.bn2(x)
|
| 77 |
+
x = self.relu2(x)
|
| 78 |
+
x = self.maxpool2(x)
|
| 79 |
+
|
| 80 |
+
# Block 3
|
| 81 |
+
x = self.conv3(x)
|
| 82 |
+
x = self.bn3(x)
|
| 83 |
+
x = self.relu3(x)
|
| 84 |
+
x = self.maxpool3(x)
|
| 85 |
+
|
| 86 |
+
# Block 4
|
| 87 |
+
x = self.conv4(x)
|
| 88 |
+
x = self.bn4(x)
|
| 89 |
+
x = self.relu4(x)
|
| 90 |
+
x = self.maxpool4(x)
|
| 91 |
+
|
| 92 |
+
# Global Average Pooling
|
| 93 |
+
x = self.global_avg_pool(x)
|
| 94 |
+
x = x.view(x.size(0), -1) # Flatten
|
| 95 |
+
|
| 96 |
+
# Classifier head
|
| 97 |
+
x = self.dropout1(x)
|
| 98 |
+
x = self.fc1(x)
|
| 99 |
+
x = self.fc1_relu(x)
|
| 100 |
+
|
| 101 |
+
x = self.dropout2(x)
|
| 102 |
+
x = self.fc2(x)
|
| 103 |
+
|
| 104 |
+
return x
|
| 105 |
+
|
| 106 |
+
def __init__(self, model_dir: str):
|
| 107 |
+
self.model_dir = model_dir
|
| 108 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 109 |
+
self.model = None
|
| 110 |
+
self.transform = None
|
| 111 |
+
self.class_names = []
|
| 112 |
+
self.metadata = None
|
| 113 |
+
|
| 114 |
+
print(f"[BasicCNN] Using device: {self.device}")
|
| 115 |
+
self._load_model()
|
| 116 |
+
|
| 117 |
+
def _load_config(self) -> Dict:
|
| 118 |
+
config_path = os.path.join(self.model_dir, "deployment_config.json")
|
| 119 |
+
if not os.path.exists(config_path):
|
| 120 |
+
raise FileNotFoundError(f"Config not found: {config_path}")
|
| 121 |
+
with open(config_path, "r") as f:
|
| 122 |
+
return json.load(f)
|
| 123 |
+
|
| 124 |
+
def _load_metadata(self, metadata_path: str) -> Dict:
|
| 125 |
+
with open(metadata_path, "r") as f:
|
| 126 |
+
return json.load(f)
|
| 127 |
+
|
| 128 |
+
def _build_transforms(self, mean: List[float], std: List[float]) -> transforms.Compose:
|
| 129 |
+
return transforms.Compose([
|
| 130 |
+
transforms.Resize((224, 224)),
|
| 131 |
+
transforms.ToTensor(),
|
| 132 |
+
transforms.Normalize(mean=mean, std=std),
|
| 133 |
+
])
|
| 134 |
+
|
| 135 |
+
def _load_model(self):
|
| 136 |
+
try:
|
| 137 |
+
config = self._load_config()
|
| 138 |
+
metadata_path = os.path.join(self.model_dir, config["metadata"])
|
| 139 |
+
state_dict_path = os.path.join(self.model_dir, config["model_state_dict"])
|
| 140 |
+
|
| 141 |
+
self.metadata = self._load_metadata(metadata_path)
|
| 142 |
+
|
| 143 |
+
# Load model
|
| 144 |
+
self.model = self.CNN(num_classes=self.metadata["num_classes"], dropout_rate=0.4)
|
| 145 |
+
state_dict = torch.load(state_dict_path, map_location=self.device)
|
| 146 |
+
self.model.load_state_dict(state_dict)
|
| 147 |
+
self.model.to(self.device)
|
| 148 |
+
self.model.eval()
|
| 149 |
+
|
| 150 |
+
# Load transforms
|
| 151 |
+
self.transform = self._build_transforms(
|
| 152 |
+
self.metadata["normalization_mean"],
|
| 153 |
+
self.metadata["normalization_std"]
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# Load class names
|
| 157 |
+
class_names_dict = self.metadata.get("class_names", {})
|
| 158 |
+
self.class_names = [class_names_dict[str(i)] for i in range(len(class_names_dict))]
|
| 159 |
+
|
| 160 |
+
print(f"[BasicCNN] Model loaded successfully. Classes: {self.class_names}")
|
| 161 |
+
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"[BasicCNN] Error loading model: {e}")
|
| 164 |
+
raise
|
| 165 |
+
|
| 166 |
+
def predict(self, image: Image.Image) -> Tuple[str, float, Dict[str, float]]:
    """Classify a PIL image with the basic CNN.

    Returns (predicted_label, confidence, {label: probability}); returns a
    placeholder tuple when no image is supplied. Re-raises (after logging)
    on any inference failure.
    """
    if image is None:
        return "No image provided", 0.0, {}

    try:
        # The transform pipeline expects a 3-channel RGB input.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")

        batch = self.transform(rgb_image).unsqueeze(0).to(self.device)

        # Inference only: disable gradient tracking.
        with torch.no_grad():
            probabilities = torch.softmax(self.model(batch), dim=1).cpu().numpy()[0]

        best_idx = int(np.argmax(probabilities))
        scores = {name: float(probabilities[i]) for i, name in enumerate(self.class_names)}

        return self.class_names[best_idx], float(probabilities[best_idx]), scores

    except Exception as e:
        print(f"[BasicCNN] Error during prediction: {e}")
        raise
|
model_handlers/hugging_face_handler.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Tuple, Dict
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class HuggingFaceModel:
    """Wraps a fine-tuned ViT image classifier stored in Hugging Face format."""

    def __init__(self, model_dir: str):
        # model_dir holds config.json, model.safetensors and preprocessor_config.json.
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.processor = None
        self.class_names = []

        print(f"[HuggingFace] Using device: {self.device}")
        self._load_model()

    def _load_model(self):
        """Load model + processor from self.model_dir and derive class names."""
        try:
            self.model = AutoModelForImageClassification.from_pretrained(self.model_dir)
            self.processor = AutoImageProcessor.from_pretrained(self.model_dir)

            # Inference only: move to device and freeze in eval mode.
            self.model.to(self.device).eval()

            # id2label in the model config defines the label ordering.
            self.class_names = list(self.model.config.id2label.values())

            print(f"[HuggingFace] Model loaded successfully. Classes: {self.class_names}")

        except Exception as e:
            print(f"[HuggingFace] Error loading model: {e}")
            raise

    def _preprocess_image(self, img: Image.Image) -> Dict:
        """Run the HF image processor and move every tensor to self.device."""
        encoded = self.processor(images=img, return_tensors='pt')
        return {key: tensor.to(self.device) for key, tensor in encoded.items()}

    def predict(self, image: Image.Image) -> Tuple[str, float, Dict[str, float]]:
        """Classify an image.

        Returns (predicted_label, confidence, {label: probability}); returns a
        placeholder tuple when no image is supplied. Re-raises (after logging)
        on any inference failure.
        """
        if image is None:
            return "No image provided", 0.0, {}

        try:
            # Accept numpy arrays as well as PIL images.
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)

            model_inputs = self._preprocess_image(image)

            # Inference only: disable gradient tracking.
            with torch.no_grad():
                probabilities = torch.softmax(
                    self.model(**model_inputs).logits, dim=-1
                ).cpu().numpy()[0]

            best_idx = int(np.argmax(probabilities))
            scores = {name: float(probabilities[i]) for i, name in enumerate(self.class_names)}

            return self.class_names[best_idx], float(probabilities[best_idx]), scores

        except Exception as e:
            print(f"[HuggingFace] Error during prediction: {e}")
            raise
|
model_handlers/xception_handler.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Tuple, Dict
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import timm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class XceptionModel:
    """Wraps a fine-tuned timm Xception checkpoint (state dict or full model)."""

    # Class names must match training
    CLASS_NAMES = ["Auto Rickshaws", "Bikes", "Cars", "Motorcycles", "Planes", "Ships", "Trains"]

    def __init__(self, model_dir: str, model_file: str = "best_model_finetuned_full.pt"):
        self.model_dir = model_dir
        self.model_file = model_file
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.inference_transform = None
        self.class_names = self.CLASS_NAMES

        print(f"[Xception] Using device: {self.device}")
        print(f"[Xception] Classes: {self.class_names}")
        self._load_model()

    def _load_model(self):
        """Load the checkpoint, rebuilding the architecture when it is a state dict."""
        try:
            model_path = os.path.join(self.model_dir, self.model_file)

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # Disable TorchDynamo (avoids CatchErrorsWrapper issues)
            torch._dynamo.config.suppress_errors = True
            torch._dynamo.reset()

            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # required here because the checkpoint may be a full serialized model.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

            num_classes = len(self.CLASS_NAMES)

            is_state_dict = isinstance(checkpoint, dict) and not hasattr(checkpoint, "forward")
            if is_state_dict:
                # State dict: rebuild the model architecture used during training
                # (timm backbone with a custom two-layer classifier head).
                model = timm.create_model("xception", pretrained=False, num_classes=num_classes)
                in_features = model.get_classifier().in_features
                model.fc = nn.Sequential(
                    nn.Linear(in_features, 512),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Linear(512, num_classes),
                )

                # torch.compile prefixes keys with "_orig_mod."; strip before loading.
                state_dict = checkpoint
                if any(key.startswith("_orig_mod.") for key in state_dict.keys()):
                    state_dict = {key.replace("_orig_mod.", ""): value
                                  for key, value in state_dict.items()}

                model.load_state_dict(state_dict)
            else:
                # Full model object; unwrap a torch.compile wrapper if present.
                model = checkpoint
                if hasattr(model, "_orig_mod"):
                    model = model._orig_mod

            # Move model to device and set to evaluation mode
            self.model = model.to(self.device).eval()

            # Derive preprocessing transforms from the model's own data config.
            data_config = timm.data.resolve_model_data_config(self.model)
            self.inference_transform = timm.data.create_transform(**data_config, is_training=False)

            print(f"[Xception] Model loaded successfully from {model_path}")

        except Exception as e:
            print(f"[Xception] Error loading model: {e}")
            raise

    def _preprocess_image(self, img: Image.Image) -> torch.Tensor:
        """Convert to RGB, apply the timm transform, add a batch dim, move to device."""
        rgb = img.convert("RGB")
        return self.inference_transform(rgb).unsqueeze(0).to(self.device)

    def predict(self, image: Image.Image) -> Tuple[str, float, Dict[str, float]]:
        """Classify an image.

        Returns (predicted_label, confidence, {label: probability}); returns a
        placeholder tuple when no image is supplied. Re-raises (after logging)
        on any inference failure.
        """
        if image is None:
            return "No image provided", 0.0, {}

        try:
            # Accept numpy arrays as well as PIL images.
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)

            batch = self._preprocess_image(image)

            # Inference only: disable gradient tracking.
            with torch.no_grad():
                probabilities = F.softmax(self.model(batch), dim=-1).cpu().numpy()[0]

            best_idx = int(np.argmax(probabilities))
            scores = {name: float(probabilities[i]) for i, name in enumerate(self.class_names)}

            return self.class_names[best_idx], float(probabilities[best_idx]), scores

        except Exception as e:
            print(f"[Xception] Error during prediction: {e}")
            raise
|
models/basic_cnn/cnn_model_statedict_20260226_034332.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d646f75b07abffa476beeaa513f50b81904afab58cb978cee12175b1d9ce5c12
|
| 3 |
+
size 2110727
|
models/basic_cnn/deployment_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_state_dict": "cnn_model_statedict_20260226_034332.pth",
|
| 3 |
+
"metadata": "model_metadata_20260226_034332.json",
|
| 4 |
+
"label_encoder": "label_encoder_20260226_034332.pkl",
|
| 5 |
+
"input_size": [
|
| 6 |
+
224,
|
| 7 |
+
224
|
| 8 |
+
],
|
| 9 |
+
"batch_size": 32,
|
| 10 |
+
"device": "cuda"
|
| 11 |
+
}
|
models/basic_cnn/model_metadata_20260226_034332.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_architecture": "PowerfulCNN",
|
| 3 |
+
"num_classes": 7,
|
| 4 |
+
"input_shape": [
|
| 5 |
+
3,
|
| 6 |
+
224,
|
| 7 |
+
224
|
| 8 |
+
],
|
| 9 |
+
"normalization_mean": [
|
| 10 |
+
0.46052455367479384,
|
| 11 |
+
0.5079975089484482,
|
| 12 |
+
0.5388914703636689
|
| 13 |
+
],
|
| 14 |
+
"normalization_std": [
|
| 15 |
+
0.2887674684678098,
|
| 16 |
+
0.2696178694962567,
|
| 17 |
+
0.2943129167380753
|
| 18 |
+
],
|
| 19 |
+
"class_names": {
|
| 20 |
+
"0": "Auto Rickshaws",
|
| 21 |
+
"1": "Bikes",
|
| 22 |
+
"2": "Cars",
|
| 23 |
+
"3": "Motorcycles",
|
| 24 |
+
"4": "Planes",
|
| 25 |
+
"5": "Ships",
|
| 26 |
+
"6": "Trains"
|
| 27 |
+
},
|
| 28 |
+
"training_device": "cuda",
|
| 29 |
+
"saved_timestamp": "20260226_034332",
|
| 30 |
+
"model_size_mb": 2.012946128845215
|
| 31 |
+
}
|
models/hugging_face/config.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"ViTForImageClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.0,
|
| 6 |
+
"dtype": "float32",
|
| 7 |
+
"encoder_stride": 16,
|
| 8 |
+
"hidden_act": "gelu",
|
| 9 |
+
"hidden_dropout_prob": 0.0,
|
| 10 |
+
"hidden_size": 192,
|
| 11 |
+
"id2label": {
|
| 12 |
+
"0": "Auto Rickshaws",
|
| 13 |
+
"1": "Bikes",
|
| 14 |
+
"2": "Cars",
|
| 15 |
+
"3": "Motorcycles",
|
| 16 |
+
"4": "Planes",
|
| 17 |
+
"5": "Ships",
|
| 18 |
+
"6": "Trains"
|
| 19 |
+
},
|
| 20 |
+
"image_size": 224,
|
| 21 |
+
"initializer_range": 0.02,
|
| 22 |
+
"intermediate_size": 768,
|
| 23 |
+
"label2id": {
|
| 24 |
+
"Auto Rickshaws": 0,
|
| 25 |
+
"Bikes": 1,
|
| 26 |
+
"Cars": 2,
|
| 27 |
+
"Motorcycles": 3,
|
| 28 |
+
"Planes": 4,
|
| 29 |
+
"Ships": 5,
|
| 30 |
+
"Trains": 6
|
| 31 |
+
},
|
| 32 |
+
"layer_norm_eps": 1e-12,
|
| 33 |
+
"model_type": "vit",
|
| 34 |
+
"num_attention_heads": 3,
|
| 35 |
+
"num_channels": 3,
|
| 36 |
+
"num_hidden_layers": 12,
|
| 37 |
+
"patch_size": 16,
|
| 38 |
+
"pooler_act": "tanh",
|
| 39 |
+
"pooler_output_size": 192,
|
| 40 |
+
"problem_type": "single_label_classification",
|
| 41 |
+
"qkv_bias": true,
|
| 42 |
+
"transformers_version": "5.0.0",
|
| 43 |
+
"use_cache": false
|
| 44 |
+
}
|
models/hugging_face/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:891a3189e95ea1986818fd118084da21fb73369aebcacc5f1f50171354a20242
|
| 3 |
+
size 22125780
|
models/hugging_face/preprocessor_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_convert_rgb": null,
|
| 3 |
+
"do_normalize": true,
|
| 4 |
+
"do_rescale": true,
|
| 5 |
+
"do_resize": true,
|
| 6 |
+
"image_mean": [
|
| 7 |
+
0.5,
|
| 8 |
+
0.5,
|
| 9 |
+
0.5
|
| 10 |
+
],
|
| 11 |
+
"image_processor_type": "ViTImageProcessor",
|
| 12 |
+
"image_std": [
|
| 13 |
+
0.5,
|
| 14 |
+
0.5,
|
| 15 |
+
0.5
|
| 16 |
+
],
|
| 17 |
+
"resample": 2,
|
| 18 |
+
"rescale_factor": 0.00392156862745098,
|
| 19 |
+
"size": {
|
| 20 |
+
"height": 224,
|
| 21 |
+
"width": 224
|
| 22 |
+
}
|
| 23 |
+
}
|
models/xception/best_model_finetuned_full.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d08dd051737336a114b4bf8e73b1d3e6399285d9d1dee6b1d5a3e85b3066db7
|
| 3 |
+
size 87820811
|
requirements.txt
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PyTorch Unified Model Comparison
|
| 2 |
+
# Combined dependencies for all three models
|
| 3 |
+
|
| 4 |
+
# Web Framework & UI
|
| 5 |
+
gradio
|
| 6 |
+
fastapi
|
| 7 |
+
uvicorn
|
| 8 |
+
python-multipart
|
| 9 |
+
|
| 10 |
+
# Core ML Framework
|
| 11 |
+
torch
|
| 12 |
+
torchvision
|
| 13 |
+
|
| 14 |
+
# Hugging Face & Transformers
|
| 15 |
+
transformers
|
| 16 |
+
safetensors
|
| 17 |
+
datasets
|
| 18 |
+
evaluate
|
| 19 |
+
|
| 20 |
+
# Image Processing & Numerical Computing
|
| 21 |
+
pillow
|
| 22 |
+
numpy
|
| 23 |
+
opencv-python
|
| 24 |
+
|
| 25 |
+
# Vision Models Library (for Xception)
|
| 26 |
+
timm
|
| 27 |
+
|
| 28 |
+
# Additional utilities
|
| 29 |
+
pydantic
|
| 30 |
+
huggingface_hub
|
| 31 |
+
|
| 32 |
+
# Optional: machine learning utilities (metrics, preprocessing)
|
| 33 |
+
scikit-learn
|