LookThem_Tiny-ImageNet / V5CleanedCode.md
ASomeoneWhoInterestedWithAI's picture
Create V5CleanedCode.md
c1c7652 verified
|
Raw
History Blame Contribute Delete
19.9 kB

Cleaned Code

import math
import os
import shutil
import urllib.request
import zipfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# =========================================================
# 1. TINY-IMAGENET DOWNLOAD + PREPARATION
# =========================================================

def prepare_tiny_imagenet(root="."):
    """
    Download and extract Tiny-ImageNet if it is not already present.

    The archive lives at ``<root>/tiny-imagenet-200.zip`` and is unpacked
    into ``<root>/tiny-imagenet-200``. Both steps are crash-safe:

    * the download is written to ``<zip>.part`` and renamed only after it
      completes, so an interrupted download is never mistaken for a
      finished archive on the next run;
    * the archive is extracted into a staging folder and moved into place
      only after extraction completes, so a half-extracted folder is never
      mistaken for a finished dataset.

    If the extracted folder already exists, nothing is downloaded, even if
    the archive itself has since been deleted.

    Args:
        root: Directory holding the archive and the extracted dataset.
            Defaults to the current working directory.

    Returns:
        train_dir, val_dir
    """

    dataset_url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"

    zip_path = os.path.join(root, "tiny-imagenet-200.zip")

    extract_path = os.path.join(root, "tiny-imagenet-200")

    if not os.path.exists(extract_path):

        # -------------------------------------------------
        # Download dataset archive (only if not extracted yet)
        # -------------------------------------------------
        if not os.path.exists(zip_path):

            print(
                "Downloading Tiny-ImageNet (~230MB)... "
                "Please wait..."
            )

            partial_zip_path = zip_path + ".part"

            urllib.request.urlretrieve(
                dataset_url,
                partial_zip_path
            )

            # Publish the archive only once it is complete.
            os.replace(partial_zip_path, zip_path)

            print("Download complete!")

        # -------------------------------------------------
        # Extract dataset archive via a staging folder
        # -------------------------------------------------
        print("Extracting dataset...")

        staging_dir = extract_path + ".extracting"

        # Discard leftovers from a previously interrupted extraction.
        if os.path.exists(staging_dir):
            shutil.rmtree(staging_dir)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(staging_dir)

        # The archive's top-level folder is "tiny-imagenet-200"; move it
        # into its final location in one rename, then drop the staging dir.
        os.replace(
            os.path.join(staging_dir, "tiny-imagenet-200"),
            extract_path
        )
        shutil.rmtree(staging_dir)

        print("Extraction complete!")

    return (
        os.path.join(extract_path, "train"),
        os.path.join(extract_path, "val")
    )


# Make sure the dataset is on disk, then keep the roots of both splits.
(train_dir, val_dir) = prepare_tiny_imagenet()


# =========================================================
# 2. VALIDATION FOLDER RESTRUCTURING
# =========================================================
#
# Tiny-ImageNet validation images are originally placed
# in a single shared folder.
#
# This section reorganizes them into class-specific
# folders so torchvision.datasets.ImageFolder can
# load them correctly.
#

val_img_dir = "./tiny-imagenet-200/val/images"

val_annotations = (
    "./tiny-imagenet-200/val/val_annotations.txt"
)


def restructure_val_folder(img_dir, annotations_path):
    """
    Move flat validation images into one sub-folder per class.

    Each non-blank line of ``annotations_path`` is tab-separated. Its first
    field is an image file name inside ``img_dir`` and its second field is
    that image's class name. The image is moved to
    ``<parent of img_dir>/<class_name>/<image>``. Images no longer present in
    ``img_dir`` (e.g. moved by an earlier, interrupted run) are skipped, so
    re-running is safe. Finally the emptied ``img_dir`` is removed.

    Args:
        img_dir: Folder that currently holds all validation images.
        annotations_path: Path of the tab-separated annotation file.

    Raises:
        ValueError: if a non-blank annotation line has fewer than two
            tab-separated fields.
        OSError: from ``os.rmdir`` if ``img_dir`` is not empty afterwards
            (e.g. it contains images absent from the annotations). This
            stops the script instead of leaving an extra ``images`` folder
            next to the class folders.
    """

    val_root = os.path.dirname(os.path.normpath(img_dir))

    # Avoid one makedirs call per image for classes already created.
    created_dirs = set()

    with open(annotations_path, "r") as f:

        for line_no, raw_line in enumerate(f, start=1):

            line = raw_line.strip()

            # Tolerate blank lines such as an extra trailing newline.
            if not line:
                continue

            parts = line.split("\t")

            if len(parts) < 2:
                raise ValueError(
                    f"{annotations_path}:{line_no}: expected at least "
                    f"2 tab-separated fields, got {len(parts)}"
                )

            img_name, class_name = parts[0], parts[1]

            class_dir = os.path.join(val_root, class_name)

            if class_dir not in created_dirs:
                os.makedirs(class_dir, exist_ok=True)
                created_dirs.add(class_dir)

            src_path = os.path.join(img_dir, img_name)

            if os.path.exists(src_path):
                os.rename(src_path, os.path.join(class_dir, img_name))

    os.rmdir(img_dir)


if os.path.exists(val_img_dir):

    print(
        "Reorganizing Tiny-ImageNet validation "
        "folder structure..."
    )

    restructure_val_folder(val_img_dir, val_annotations)

    print(
        "Validation folder restructuring complete!"
    )


# =========================================================
# 3. DATA AUGMENTATION + NORMALIZATION
# =========================================================

# Per-channel mean and std used to normalize Tiny-ImageNet images.
# Shared by the training and validation pipelines so the two can never
# drift apart.
TINY_IMAGENET_MEAN = (0.4802, 0.4481, 0.3975)
TINY_IMAGENET_STD = (0.2302, 0.2265, 0.2262)

transform_train = transforms.Compose([

    # Horizontal augmentation
    transforms.RandomHorizontalFlip(),

    # Mild rotational augmentation (angle drawn from [-15, 15] degrees)
    transforms.RandomRotation(15),

    transforms.ToTensor(),

    # Tiny-ImageNet normalization statistics
    transforms.Normalize(TINY_IMAGENET_MEAN, TINY_IMAGENET_STD)
])

# Evaluation pipeline: no augmentation, only tensor conversion and the
# same normalization as training.
transform_val = transforms.Compose([

    transforms.ToTensor(),

    transforms.Normalize(TINY_IMAGENET_MEAN, TINY_IMAGENET_STD)
])


# =========================================================
# 4. DATASET + DATALOADER SETUP
# =========================================================

train_dataset = datasets.ImageFolder(
    root=train_dir,
    transform=transform_train
)

val_dataset = datasets.ImageFolder(
    root=val_dir,
    transform=transform_val
)

# ImageFolder assigns label indices from the (sorted) class-folder names of
# each root independently, so both splits must expose exactly the same
# class folders. Otherwise validation labels would be silently offset from
# the training labels, e.g. if a leftover ``val/images`` folder were picked
# up as an extra class. The reported accuracy would then be meaningless.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise RuntimeError(
        "Train/validation class mismatch: "
        f"{len(train_dataset.classes)} training classes vs "
        f"{len(val_dataset.classes)} validation classes. "
        "Check the validation folder restructuring step."
    )

# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so
# request it only when CUDA is available.
use_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory
)


# =========================================================
# 5. CORE RELATIONAL LAYER β€” LOOKTHEM LAYER
# =========================================================

class LookThemLayer(nn.Module):
    """
    Token-relational processing layer.

    Each token owns two independent micro-networks (branch 1 and branch 2).
    Each is a per-token two-layer MLP ``in_features -> hidden_dim -> 1``
    with a GELU in between, producing scalars ``m1[t]`` and ``m2[t]``.

    For every ordered token pair (i, j) the layer forms two ratio-based
    comparisons::

        compare[i, j]  = tanh(m1[i] / m2[j])
        compare2[i, j] = tanh(m1[j] / m2[i])

    Both are scaled and shifted by a learnable weight and bias belonging to
    token j. They then weight the features of tokens i and j respectively,
    the two terms are averaged, and the result is averaged over all
    partners j != i.

    Args:
        num_tokens: Number of tokens T; every token has its own weights.
        in_features: Feature size of each token (input and output).
        hidden_dim: Hidden width of each per-token micro-network.

    Shapes:
        input  ``[B, num_tokens, in_features]``
        output ``[B, num_tokens, in_features]``
    """

    # Smallest magnitude the branch-2 denominator is allowed to take.
    _DENOM_EPS = 1e-5

    def __init__(self, num_tokens, in_features, hidden_dim):

        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        # =================================================
        # BRANCH 1 PARAMETERS (per token: in -> hidden -> 1)
        # =================================================
        self.mod1_w1 = nn.Parameter(
            torch.empty(num_tokens, in_features, hidden_dim)
        )
        self.mod1_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod1_w2 = nn.Parameter(
            torch.empty(num_tokens, hidden_dim, 1)
        )
        self.mod1_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # =================================================
        # BRANCH 2 PARAMETERS (per token: in -> hidden -> 1)
        # =================================================
        self.mod2_w1 = nn.Parameter(
            torch.empty(num_tokens, in_features, hidden_dim)
        )
        self.mod2_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod2_w2 = nn.Parameter(
            torch.empty(num_tokens, hidden_dim, 1)
        )
        self.mod2_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # =================================================
        # RELATIONAL TRANSFORMATION PARAMETERS
        # (one scalar weight and bias per compared token j)
        # =================================================
        self.trans_w = nn.Parameter(
            torch.empty(num_tokens, 1, 1)
        )
        self.trans_b = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """
        Kaiming-uniform initialization (``a=sqrt(5)``, as in ``nn.Linear``),
        applied per token.

        ``nn.init.kaiming_uniform_`` on a 3-D tensor ``[T, fan_in, fan_out]``
        computes ``fan_in = shape[1] * shape[2]``, treating the output width
        as a receptive field. That makes the ``[T, in, hidden]`` weights
        ``sqrt(hidden)`` times too small. Each token slice is an independent
        ``fan_in -> fan_out`` projection, so the fan-in is ``shape[1]``
        alone. For ``a = sqrt(5)`` the Kaiming-uniform bound simplifies to
        ``1 / sqrt(fan_in)``.
        """

        for w in (
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w
        ):
            bound = 1.0 / math.sqrt(w.shape[1])
            nn.init.uniform_(w, -bound, bound)

    @staticmethod
    def _token_mlp(x, w1, b1, w2, b2):
        """Apply a per-token two-layer GELU MLP: [B, T, F] -> [B, T, 1]."""

        hidden = torch.einsum('bti,tij->btj', x, w1) + b1

        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Input shape:
            [B, Tokens, Features]

        Output shape:
            [B, Tokens, Features]
        """

        N = self.num_tokens

        # =================================================
        # PER-TOKEN BRANCH OUTPUTS: [B, N, 1] each
        # =================================================
        out_m1 = self._token_mlp(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        out_m2 = self._token_mlp(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # Numerical stabilization. out_m2 is signed and unbounded, so a
        # plain "+ eps" can still land on (or next to) zero. Keep the +eps
        # shift, then push anything still within eps of zero out to +/-eps
        # while preserving its sign. Values with |out_m2 + eps| >= eps are
        # left exactly as "out_m2 + eps".
        eps = self._DENOM_EPS
        shifted = out_m2 + eps
        out_m2_safe = torch.copysign(shifted.abs().clamp(min=eps), shifted)

        # =================================================
        # PAIRWISE TOKEN RELATIONAL COMPARISON: [B, N, N, 1]
        # =================================================

        # compare[b, i, j] = tanh(m1[i] / m2[j])
        compare = torch.tanh(
            out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1)
        )

        # compare2[b, i, j] = tanh(m1[j] / m2[i])
        compare2 = torch.tanh(
            out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2)
        )

        # =================================================
        # RELATIONAL MAP TRANSFORMATION (weight/bias of token j)
        # =================================================
        bias_reshaped = self.trans_b.view(1, 1, N, 1)

        trans_compare = (
            torch.einsum('bije,jef->bijf', compare, self.trans_w)
            + bias_reshaped
        )

        trans_compare2 = (
            torch.einsum('bije,jef->bijf', compare2, self.trans_w)
            + bias_reshaped
        )

        # =================================================
        # BIDIRECTIONAL INTERACTION FUSION: [B, N, N, F]
        # =================================================
        interaction = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Remove self-interaction (the i == j diagonal).
        mask = 1.0 - torch.eye(N, device=x.device, dtype=x.dtype)

        interaction_masked = interaction * mask.view(1, N, N, 1)

        # Average over the N - 1 partner tokens. A single token has no
        # partners: divide by 1 (giving zeros) rather than by 0 (NaN).
        return interaction_masked.sum(dim=2) / max(N - 1, 1)


# =========================================================
# 6. MAIN ARCHITECTURE β€” LOOKTHEM V5
# =========================================================

class LookThemV5(nn.Module):
    """
    Dual-stream asymmetric relational classifier (200 output classes).

    Stream A (structure): the image is collapsed to one grayscale channel
    and downsampled by two stride-2 conv stages. The 16x16 grid (for a
    64x64 input) is linearly compressed from 256 positions into 64 tokens
    of 32 features.

    Stream B (color): the RGB image goes through three stride-2 conv
    stages, ending on an 8x8 grid, i.e. 64 tokens of 32 features.

    Both token sets are concatenated along the feature axis (64 tokens x 64
    features), processed by ``LookThemLayer`` and classified by a small MLP
    head.
    """

    def __init__(self):

        super().__init__()

        # Luma coefficients for RGB -> grayscale, shaped to broadcast over
        # a [B, 3, H, W] batch.
        luma = torch.tensor([0.299, 0.587, 0.114])
        self.register_buffer('grayscale_weights', luma.view(1, 3, 1, 1))

        # Stream A: grayscale, two stride-2 stages -> [B, 32, 16, 16].
        self.stream_a = nn.Sequential(
            *self._conv_stage(1, 16),
            *self._conv_stage(16, 32),
        )

        # Token bridge: maps 256 spatial positions onto 64 tokens,
        # channel by channel.
        self.token_bridge = nn.Linear(256, 64)

        # Stream B: RGB, three stride-2 stages -> [B, 32, 8, 8].
        self.stream_b = nn.Sequential(
            *self._conv_stage(3, 16),
            *self._conv_stage(16, 32),
            *self._conv_stage(32, 32),
        )

        # Relational core over 64 tokens of 64 fused features.
        self.lookthem = LookThemLayer(
            num_tokens=64,
            in_features=64,
            hidden_dim=32
        )

        # Head: flatten the 64x64 token map, one dropout-regularized
        # hidden layer, 200 logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 64, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 200),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the layers of one stride-2 3x3 conv -> BN -> GELU stage."""

        return [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        ]

    def forward(self, x):

        n = x.size(0)

        # Stream A. NOTE(review): if x arrives already normalized per
        # channel (as this script's transforms do), this is a weighted mix
        # of normalized channels rather than true luminance.
        gray = (x * self.grayscale_weights).sum(dim=1, keepdim=True)

        # [B, 32, 16, 16] -> [B, 32, 256] -> [B, 32, 64] -> [B, 64, 32]
        spatial_a = self.stream_a(gray).view(n, 32, 256)
        tokens_a = self.token_bridge(spatial_a).transpose(1, 2)

        # Stream B: [B, 32, 8, 8] -> [B, 32, 64] -> [B, 64, 32]
        tokens_b = self.stream_b(x).view(n, 32, 64).transpose(1, 2)

        # Same token count, doubled features: [B, 64, 64].
        fused = torch.cat((tokens_a, tokens_b), dim=2)

        related = self.lookthem(fused)

        return self.classifier(related)


# =========================================================
# 7. TRAINING RUNTIME + CHECKPOINT SYSTEM
# =========================================================

# Train on the GPU whenever one is visible.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = LookThemV5().to(device)

criterion = nn.CrossEntropyLoss()

# Adam with a small weight-decay term; cosine schedule over 20 epochs.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# Resume bookkeeping: where checkpoints live and which epoch to start at.
checkpoint_path = "lookthem_v5_checkpoint.pth"
start_epoch = 0


# =========================================================
# CHECKPOINT RESUME
# =========================================================

if os.path.exists(checkpoint_path):

    print(
        "Checkpoint detected. "
        "Resuming previous experiment..."
    )

    # map_location places every stored tensor on the current device, so a
    # checkpoint written on a GPU machine can be resumed on a CPU-only one
    # (deserializing CUDA tensors without CUDA would otherwise fail).
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # The training loop stores the number of completed epochs under
    # 'epoch', which is also the 0-based index of the next epoch to run.
    start_epoch = checkpoint['epoch']

    print(
        f"Successfully resumed from "
        f"epoch {start_epoch + 1}"
    )


print(
    f"Starting LookThem V5 "
    f"(Asymmetric Fusion) on {device}..."
)


# =========================================================
# 8. TRAINING LOOP
# =========================================================

for epoch in range(start_epoch, 20):

    model.train()

    total_loss = 0.0
    correct = 0
    total = 0

    for data, target in train_loader:

        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        output = model(data)

        loss = criterion(output, target)

        loss.backward()

        optimizer.step()

        # CrossEntropyLoss() returns the *mean* over the batch; weight it by
        # the batch size so the epoch loss is a true per-sample mean even
        # when the last batch is smaller than the others.
        n_samples = target.size(0)

        total_loss += loss.item() * n_samples
        total += n_samples

        correct += output.argmax(dim=1).eq(target).sum().item()

    scheduler.step()

    acc = 100. * correct / total

    current_lr = optimizer.param_groups[0]['lr']

    print(
        f"Epoch {epoch+1:02d}/20 | "
        f"Train Loss: "
        f"{total_loss / total:.4f} | "
        f"Train Acc: {acc:.2f}% | "
        f"LR: {current_lr:.6f}"
    )

    # -----------------------------------------------------
    # Periodic checkpoint save
    # -----------------------------------------------------
    if (epoch + 1) % 5 == 0:

        # Write to a temporary file, then swap it in with os.replace, so
        # an interruption mid-save never leaves a truncated checkpoint
        # where a valid one used to be.
        tmp_checkpoint_path = checkpoint_path + ".tmp"

        torch.save({

            # Number of completed epochs (= next epoch index on resume).
            'epoch': epoch + 1,

            'model_state_dict':
                model.state_dict(),

            'optimizer_state_dict':
                optimizer.state_dict(),

            'scheduler_state_dict':
                scheduler.state_dict(),

        }, tmp_checkpoint_path)

        os.replace(tmp_checkpoint_path, checkpoint_path)

        print(
            f"[CHECKPOINT] "
            f"Epoch {epoch+1} saved successfully."
        )


# =========================================================
# 9. FINAL VALIDATION
# =========================================================

# Evaluate on the labeled Tiny-ImageNet validation split (val_loader),
# which this script reports as the "test" result.
model.eval()

test_loss = 0.0
test_correct = 0
test_total = 0

print("\nStarting final validation...")

with torch.no_grad():

    for data, target in val_loader:

        data = data.to(device)
        target = target.to(device)

        output = model(data)

        loss = criterion(output, target)

        # Weight the batch-mean loss by batch size so the smaller final
        # batch does not skew the reported average.
        n_samples = target.size(0)

        test_loss += loss.item() * n_samples
        test_total += n_samples

        test_correct += output.argmax(dim=1).eq(target).sum().item()

final_test_acc = (
    100. * test_correct / test_total
)

print("=== FINAL LOOKTHEM V5 RESULTS ===")

print(
    f"Test Loss: "
    f"{test_loss / test_total:.4f} | "
    f"Test Accuracy: {final_test_acc:.2f}%"
)

# Save final trained weights (model parameters and buffers only).
final_model_path = "LookThem_V5_Final.pth"

torch.save(model.state_dict(), final_model_path)

print(
    f"Training complete! "
    f"Final model size: "
    f"{os.path.getsize(final_model_path) / (1024*1024):.2f} MB"
)