| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | import warnings |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | import gradio as gr |
| | import numpy as np |
| | import torch |
| | from PIL import Image, ImageDraw, ImageFont |
| | from sklearn.cluster import KMeans |
| | from sklearn.decomposition import PCA |
| | from transformers import ( |
| | AutoImageProcessor, |
| | ViTModel, |
| | ViTForImageClassification, |
| | AutoConfig, |
| | ) |
| | import plotly.express as px |
| |
|
# Silence library warning chatter (transformers/sklearn deprecations) in the app logs.
warnings.filterwarnings("ignore")

# Pretrained ViT checkpoint: 16x16 patches over 224x224 input -> 14x14 = 196 patch tokens.
MODEL_NAME = "google/vit-base-patch16-224"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily-initialized singletons; populated by load_models() on first use.
BASE_MODEL = None
CLF_MODEL = None
PROCESSOR = None
| |
|
| |
|
| | |
def load_models():
    """Lazily load and cache the ViT backbone, classifier head, and image processor.

    Returns the cached (BASE_MODEL, CLF_MODEL, PROCESSOR) triple, loading them
    from the hub on the first call only.
    """
    global BASE_MODEL, CLF_MODEL, PROCESSOR
    cached = (BASE_MODEL, CLF_MODEL, PROCESSOR)
    if all(item is not None for item in cached):
        return cached

    PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_NAME)

    # Eager attention is required so the model can return per-head attention
    # weights; hidden states are needed for the PCA views.
    config = AutoConfig.from_pretrained(MODEL_NAME)
    config.attn_implementation = "eager"
    config.output_attentions = True
    config.output_hidden_states = True

    # Backbone exposing attentions/hidden states.
    BASE_MODEL = ViTModel.from_pretrained(MODEL_NAME, config=config)
    BASE_MODEL.to(DEVICE).eval()

    # Separate classification head for top-5 predictions.
    CLF_MODEL = ViTForImageClassification.from_pretrained(MODEL_NAME)
    CLF_MODEL.to(DEVICE).eval()

    return BASE_MODEL, CLF_MODEL, PROCESSOR
| |
|
| |
|
| | |
def patch_grid_info(image_size: int = 224, patch_size: int = 16):
    """Describe the square patch grid ViT imposes on the image.

    Returns (grid_size, centers) where grid_size is patches per side and
    centers lists each patch's pixel center (x, y), ordered left-to-right,
    top-to-bottom to match ViT's patch token order.
    """
    grid_size = image_size // patch_size
    centers = [
        (int((col + 0.5) * patch_size), int((row + 0.5) * patch_size))
        for row in range(grid_size)
        for col in range(grid_size)
    ]
    return grid_size, centers
| |
|
| |
|
| | |
def draw_patch_grid(img: Image.Image, patch_size: int = 16, outline=(0, 180, 0)) -> Image.Image:
    """Overlay the ViT patch boundaries on a 224x224 copy of *img*."""
    canvas = img.convert("RGB").resize((224, 224))
    pen = ImageDraw.Draw(canvas)
    side = canvas.size[0]
    # Canvas is square (224x224), so one offset drives both a vertical and a
    # horizontal grid line per patch boundary.
    for offset in range(0, side, patch_size):
        pen.line([(offset, 0), (offset, side)], fill=outline, width=1)
        pen.line([(0, offset), (side, offset)], fill=outline, width=1)
    return canvas
| |
|
| |
|
def draw_cluster_blocks(img: Image.Image, labels: np.ndarray, n_clusters: int = 4, patch_size: int = 16):
    """Tint each 16x16 patch with a translucent color keyed to its cluster id.

    labels: (n_patches,) cluster labels per patch index (left-to-right,
    top-to-bottom). n_clusters is kept for API symmetry; colors simply cycle
    through a fixed six-entry palette.
    """
    canvas = img.convert("RGB").resize((224, 224))
    pen = ImageDraw.Draw(canvas, "RGBA")
    grid_size, _ = patch_grid_info()
    palette = [
        (255, 99, 71, 140),
        (60, 179, 113, 140),
        (65, 105, 225, 140),
        (255, 215, 0, 140),
        (199, 21, 133, 140),
        (0, 206, 209, 140),
    ]
    for patch_idx, cluster_id in enumerate(labels):
        row, col = divmod(patch_idx, grid_size)
        left = col * patch_size
        top = row * patch_size
        pen.rectangle(
            [left, top, left + patch_size, top + patch_size],
            fill=palette[int(cluster_id) % len(palette)],
        )
    return canvas
| |
|
| |
|
def draw_attention_arrows(img: Image.Image, att_matrix: np.ndarray, top_k: int = 3, query_idx: Optional[int] = None):
    """Draw arrows from one query patch to its top-k most-attended patches.

    att_matrix: (n_patches, n_patches) patch-to-patch attention weights.
    query_idx:  index of the query patch; when None the center patch of the
                grid is used. Arrows run from the query patch center to each
                of the top-k key patch centers; the query patch is circled.
    """
    canvas = img.convert("RGB").resize((224, 224))
    pen = ImageDraw.Draw(canvas, "RGBA")
    grid_size, centers = patch_grid_info()

    if query_idx is None:
        query_idx = (grid_size * grid_size) // 2  # middle of the grid
    src = centers[query_idx]

    weights = att_matrix[query_idx]
    strongest = weights.argsort()[-top_k:][::-1]
    for key_idx in strongest:
        dst = centers[key_idx]
        pen.line([src, dst], fill=(255, 0, 0, 200), width=3)
        # Arrow head: two short wings angled 0.3 rad off the line direction.
        angle = math.atan2(dst[1] - src[1], dst[0] - src[0])
        head_len = 8
        wing_a = (dst[0] - head_len * math.cos(angle - 0.3), dst[1] - head_len * math.sin(angle - 0.3))
        wing_b = (dst[0] - head_len * math.cos(angle + 0.3), dst[1] - head_len * math.sin(angle + 0.3))
        pen.polygon([dst, wing_a, wing_b], fill=(255, 0, 0, 200))

    # Circle the query patch in blue so the arrow origin is visible.
    radius = 10
    pen.ellipse(
        [src[0] - radius, src[1] - radius, src[0] + radius, src[1] + radius],
        outline=(0, 0, 255, 220),
        width=2,
    )
    return canvas
| |
|
| |
|
def make_focus_overlay(img: Image.Image, heat_grid: np.ndarray, alpha: float = 0.6):
    """Overlay a red-orange heatmap on *img* where *heat_grid* is high.

    heat_grid: (G, G) float map; it is min-max normalized, upsampled to
    224x224 bilinearly, and alpha-blended over the image. Pixels with
    normalized heat <= 0.05 are left untouched.

    Fix: the original iterated all 224*224 pixels with `draw.point`, which is
    ~50k Python-level calls per overlay; this version performs one vectorized
    NumPy alpha blend with the same color ramp (red rises with heat, green
    fades, blue fixed at 40) and per-pixel opacity alpha * heat.
    """
    img = img.convert("RGB").resize((224, 224))
    g = np.array(heat_grid, dtype=np.float32)
    if np.any(g):
        g = g - g.min()
        if g.max() > 0:
            g = g / g.max()
    else:
        g = np.zeros_like(g)
    heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
    heat = np.array(heat_img).astype(np.float32) / 255.0

    base = np.asarray(img, dtype=np.float32)
    # Color ramp matching the original per-pixel fill.
    color = np.empty_like(base)
    color[..., 0] = 255.0 * heat
    color[..., 1] = 200.0 * (1.0 - heat)
    color[..., 2] = 40.0
    # Per-pixel opacity: zero below the visibility threshold, else alpha * heat.
    a = np.where(heat > 0.05, alpha * heat, 0.0)[..., None].astype(np.float32)
    blended = base * (1.0 - a) + color * a
    return Image.fromarray(np.clip(blended + 0.5, 0, 255).astype("uint8"), "RGB")
| |
|
| |
|
| | |
def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
    """Aggregate attention across layers via attention rollout.

    Each element of *all_attentions* is a (1, heads, seq, seq) tensor. Per
    layer, heads are averaged, an identity matrix is added to account for
    residual connections, rows are re-normalized to sum to one, and the
    resulting matrices are multiplied from the first layer to the last.
    Returns the accumulated (seq, seq) rollout matrix.
    """
    rollout = None
    for layer_att in all_attentions:
        head_mean = layer_att[0].mean(dim=0).detach().cpu().numpy()
        fused = head_mean + np.eye(head_mean.shape[0])
        denom = fused.sum(axis=-1, keepdims=True)
        denom[denom == 0] = 1.0  # guard against all-zero rows
        fused = fused / denom
        rollout = fused if rollout is None else fused @ rollout
    return rollout
| |
|
| |
|
| | |
def pca_plot_from_hidden(hidden_states: List[torch.Tensor], layers: List[int]):
    """Scatter-plot 2-D PCA projections of patch embeddings for chosen layers.

    A separate PCA is fitted per layer (so axes are not comparable across
    layers); points are colored by layer index.
    """
    coords_per_layer = []
    layer_tags = []
    for layer_idx in layers:
        tokens = hidden_states[layer_idx][0].detach().cpu().numpy()
        patch_tokens = tokens[1:, :]  # drop the CLS token
        coords_per_layer.append(PCA(n_components=2).fit_transform(patch_tokens))
        layer_tags.append(np.array([layer_idx] * patch_tokens.shape[0]))
    coords = np.vstack(coords_per_layer)
    tags = np.concatenate(layer_tags)
    df = {"x": coords[:, 0], "y": coords[:, 1], "layer": tags.astype(str)}
    fig = px.scatter(df, x="x", y="y", color="layer", title="Patch embeddings across layers (PCA)")
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(height=480)
    return fig
| |
|
| |
|
| | |
def analyze_all(img: Optional[Image.Image], mode_simple: bool):
    """Run the full ViT pipeline on *img* and build every visualization at once.

    Returns an 11-tuple matching the Gradio output wiring:
    (patch-grid image, cluster image, arrow image, rollout-focus image,
     top-5 text, simple explanation, CLS-attention overlay, PCA figure,
     top-5 text again, advanced explanation, state dict).

    mode_simple is accepted for interface compatibility; both tabs receive
    the same outputs and display the subset they need.
    """
    if img is None:
        # BUG FIX: the early return must yield exactly as many values as the
        # function's contract (11); the original returned only 8, which makes
        # Gradio raise "too few output values" when no image is uploaded.
        return None, None, None, None, "", "", None, None, "", "", None

    base, clf, processor = load_models()

    # Preprocess to the model's native 224x224 resolution.
    img224 = img.convert("RGB").resize((224, 224))
    inputs = processor(images=img224, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = base(**inputs)

    attentions = outputs.attentions        # per layer: (1, heads, seq, seq)
    hidden_states = outputs.hidden_states  # per layer (+embeddings): (1, seq, dim)

    grid_size, _ = patch_grid_info()
    seq_len = attentions[0].shape[-1]
    n_patches = seq_len - 1  # sequence = [CLS] + patch tokens

    # Step 1: raw patch grid.
    patch_grid_img = draw_patch_grid(img224.copy())

    # Step 2: cluster last-layer patch embeddings to illustrate grouping.
    last_hidden = hidden_states[-1][0].detach().cpu().numpy()
    patch_embeddings = last_hidden[1:, :]
    n_clusters = 4
    try:
        cluster_labels = KMeans(n_clusters=n_clusters, random_state=0).fit(patch_embeddings).labels_
    except Exception:
        # Degrade gracefully (e.g. degenerate embeddings): single cluster.
        cluster_labels = np.zeros(n_patches, dtype=int)
    cluster_img = draw_cluster_blocks(img224.copy(), cluster_labels, n_clusters=n_clusters)

    # Step 3: arrows from the center patch to its most-attended patches
    # (last layer, averaged over heads, CLS row/column dropped).
    last_att = attentions[-1][0].mean(dim=0).cpu().numpy()
    if last_att.shape[0] >= n_patches + 1:
        patch_to_patch = last_att[1:, 1:]
    else:
        patch_to_patch = np.zeros((n_patches, n_patches))
    arrow_img = draw_attention_arrows(img224.copy(), patch_to_patch, top_k=4, query_idx=(n_patches // 2))

    # Step 4: attention-rollout heatmap (CLS -> patches, aggregated over layers).
    rollout = compute_attention_rollout(attentions)
    rollout_cls = rollout[0, 1:]
    if rollout_cls.shape[0] != grid_size * grid_size:
        # Pad/truncate defensively if seq length doesn't match the square grid.
        tmp = np.zeros(grid_size * grid_size, dtype=float)
        nmin = min(len(rollout_cls), tmp.shape[0])
        tmp[:nmin] = rollout_cls[:nmin]
        rollout_cls = tmp
    focus_img = make_focus_overlay(img224.copy(), rollout_cls.reshape(grid_size, grid_size), alpha=0.6)

    # Classification: manual softmax over logits, then the top-5 labels.
    with torch.no_grad():
        logits = clf(**inputs).logits[0].cpu().numpy()
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()
    top5 = probs.argsort()[-5:][::-1]
    labels = clf.config.id2label
    preds_text = "\n".join([f"{labels[i]} — {probs[i]*100:.2f}%" for i in top5])

    # PCA of patch embeddings at first / middle / last layer.
    pca_fig = pca_plot_from_hidden(hidden_states, [0, max(0, len(hidden_states) // 2), len(hidden_states) - 1])

    # Default advanced overlay: head-averaged CLS->patch attention, last layer.
    # (The redundant local `from PIL import Image` in the original is removed;
    # Image is already imported at module level.)
    att_np = attentions[-1][0].cpu().numpy()
    cls_to_patches = att_np.mean(axis=0)[0, 1:]
    if cls_to_patches.shape[0] != grid_size * grid_size:
        tmp = np.zeros(grid_size * grid_size, dtype=float)
        nmin = min(len(cls_to_patches), tmp.shape[0])
        tmp[:nmin] = cls_to_patches[:nmin]
        cls_to_patches = tmp
    focus_overlay_default = make_focus_overlay(img224.copy(), cls_to_patches.reshape(grid_size, grid_size), alpha=0.5)

    # State handed to the advanced tab's slider callbacks (CPU tensors only,
    # so the overlays can be recomputed without re-running the model).
    state = {
        "attentions": [a.cpu() for a in attentions],
        "hidden_states": [h.cpu() for h in hidden_states],
        "grid_size": grid_size,
        "num_layers": len(attentions),
        "num_heads": attentions[0].shape[1],
        "base_image": img,
    }

    simple_explain = """
**How ViT Sees — Simple Steps**

1) **Chop** — The image is chopped into small square tiles (patches) like LEGO pieces.
2) **Understand** — Each piece gets a code that describes colors/edges. Pieces that look similar are grouped.
3) **Talk** — Pieces tell each other what they see (we draw arrows to show that).
4) **Focus & Guess** — The model merges clues and focuses on important areas, then guesses what the image shows.
"""

    advanced_explain = """
**Advanced View:** Explore attention per layer/head, the PCA of patch embeddings, and the model's internal focus.
Use sliders to change layer/head and see how ViT's attention evolves.
"""

    return (
        patch_grid_img,
        cluster_img,
        arrow_img,
        focus_img,
        preds_text,
        simple_explain,
        focus_overlay_default,
        pca_fig,
        preds_text,
        advanced_explain,
        state,
    )
| |
|
| |
|
| | |
def advanced_update_attention(state: Dict[str, Any], layer_idx: int, head_idx: int):
    """Render the CLS->patch attention map for one (layer, head) as an overlay."""
    if not state:
        return None
    # Clamp slider values into the valid layer/head ranges.
    layer = min(max(int(layer_idx), 0), state["num_layers"] - 1)
    head = min(max(int(head_idx), 0), state["num_heads"] - 1)
    att = state["attentions"][layer]
    if att.ndim == 4:
        att = att[0]  # drop the batch dimension
    # CLS query (row 0) attending to the patch keys (columns 1..).
    cls_row = att.numpy()[head, 0, 1:]
    grid = state["grid_size"]
    if cls_row.shape[0] != grid * grid:
        # Pad/truncate defensively to fit the square grid.
        padded = np.zeros(grid * grid, dtype=float)
        count = min(cls_row.shape[0], padded.shape[0])
        padded[:count] = cls_row[:count]
        cls_row = padded
    return make_focus_overlay(state["base_image"].convert("RGB"), cls_row.reshape(grid, grid), alpha=0.55)
| |
|
| |
|
def advanced_update_rollout(state: Dict[str, Any]):
    """Recompute the attention-rollout overlay from cached per-layer attentions."""
    if not state:
        return None
    # compute_attention_rollout expects batched (1, heads, seq, seq) tensors.
    batched = [att if att.ndim == 4 else att.unsqueeze(0) for att in state["attentions"]]
    rollout = compute_attention_rollout(batched)
    grid = state["grid_size"]
    cls_row = rollout[0, 1:]
    if cls_row.shape[0] != grid * grid:
        # Pad/truncate defensively to fit the square grid.
        padded = np.zeros(grid * grid, dtype=float)
        count = min(len(cls_row), padded.shape[0])
        padded[:count] = cls_row[:count]
        cls_row = padded
    return make_focus_overlay(state["base_image"].convert("RGB"), cls_row.reshape(grid, grid), alpha=0.6)
| |
|
| |
|
def advanced_update_pca(state: Dict[str, Any], txt: str):
    """Parse a comma-separated layer list from *txt* and rebuild the PCA plot."""
    if not state:
        return None
    try:
        requested = [int(token.strip()) for token in txt.split(",") if token.strip()]
    except Exception:
        # Malformed input: fall back to first and last layers.
        requested = [0, max(0, state["num_layers"] - 1)]
    return pca_plot_from_hidden(state["hidden_states"], requested)
| |
|
| |
|
| | |
# ---------------------------------------------------------------------------
# Gradio UI: a "Simple" story-style tab and an "Advanced" tab, both driven by
# analyze_all(), which returns 11 values (images, text, a plot, and a state
# dict). Each tab wires only the outputs it displays to visible components
# and sinks the rest into throwaway, unattached components.
# ---------------------------------------------------------------------------
with gr.Blocks(title="ViT Visualizer — Simple + Advanced") as demo:
    gr.Markdown("# 👀 How Vision Transformers (ViT) See Images\n"
                "Simple mode (story-style) + Advanced mode (inspect internals). Model: **google/vit-base-patch16-224**")

    with gr.Tabs():
        with gr.TabItem("Simple (for everyone)"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image(label="Upload an image (photo / object)", type="pil")
                    run_btn = gr.Button("🔎 Explain simply")
                    gr.Markdown("Tip: use clear images of objects, animals, scenes for best examples.")
                with gr.Column(scale=1):
                    # Placeholder column to balance the layout.
                    pass

            gr.Markdown("### Step 1 — Chopped into patches")
            step1 = gr.Image(label="Patch Grid (ViT chops image into 16×16 patches)")

            gr.Markdown("### Step 2 — The model groups similar patches")
            step2 = gr.Image(label="Clustered patches (colored blocks)")

            gr.Markdown("### Step 3 — Patches talk to each other (simplified)")
            step3 = gr.Image(label="Patch-to-Patch arrows")

            gr.Markdown("### Step 4 — Model focus map and guess")
            with gr.Row():
                step4 = gr.Image(label="Focus map (where model looked most)")
                preds_simple = gr.Textbox(label="Model guesses (Top-5)", lines=4)

            explanation_simple = gr.Markdown()

            # analyze_all returns 11 values; the trailing inline components
            # (gr.State(), gr.Plot(), ...) are hidden sinks that absorb the
            # advanced-tab outputs this tab does not show. gr.State(True)
            # pins mode_simple=True for this tab.
            run_btn.click(
                fn=analyze_all,
                inputs=[img_input, gr.State(True)],
                outputs=[step1, step2, step3, step4, preds_simple, explanation_simple,
                         gr.State(), gr.Plot(), gr.Textbox(), gr.Markdown(), gr.State()],
            )

        with gr.TabItem("Advanced (inspect internals)"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_adv = gr.Image(label="Upload image for advanced view", type="pil")
                    run_adv = gr.Button("Analyze (advanced)")
                    gr.Markdown("Use the sliders to explore attention per layer and head.")
                    # NOTE(review): slider ranges are hard-coded to ViT-Base's
                    # 12 layers / 12 heads; confirm if the checkpoint changes.
                    layer_slider = gr.Slider(0, 11, value=11, step=1, label="Layer (0=shallow)")
                    head_slider = gr.Slider(0, 11, value=0, step=1, label="Head index")
                    rollout_btn = gr.Button("Refresh Rollout Overlay")
                    pca_txt = gr.Textbox(label="PCA layers (comma separated)", value="0,6,11")
                    pca_btn = gr.Button("Update PCA")
                with gr.Column(scale=1):
                    adv_attn = gr.Image(label="Attention overlay (layer/head CLS->patch)")
                    adv_rollout = gr.Image(label="Attention rollout overlay (aggregated)")
                    adv_pca = gr.Plot(label="PCA of patch embeddings")
                    adv_preds = gr.Textbox(label="Top-5 predictions", lines=5)
                    adv_explain = gr.Markdown()

            # Cached model outputs shared with the slider/button callbacks below.
            state_box = gr.State()

            # Same 11-output contract as the simple tab; the first four image
            # slots and the simple-explanation slot are discarded via inline
            # sink components.
            run_adv.click(
                fn=analyze_all,
                inputs=[img_adv, gr.State(False)],
                outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image(), adv_preds, gr.Markdown(),
                         adv_attn, adv_pca, adv_preds, adv_explain, state_box],
            )

            # Both sliders re-render the per-layer/head attention overlay from state.
            layer_slider.change(
                fn=advanced_update_attention,
                inputs=[state_box, layer_slider, head_slider],
                outputs=[adv_attn],
            )
            head_slider.change(
                fn=advanced_update_attention,
                inputs=[state_box, layer_slider, head_slider],
                outputs=[adv_attn],
            )

            rollout_btn.click(
                fn=advanced_update_rollout,
                inputs=[state_box],
                outputs=[adv_rollout],
            )

            pca_btn.click(
                fn=advanced_update_pca,
                inputs=[state_box, pca_txt],
                outputs=[adv_pca],
            )

demo.launch()