A Unified View of Masked Image Modeling
Paper: [arXiv:2210.10615](https://arxiv.org/abs/2210.10615)
The first open-source PyTorch implementation of MaskDistill with pre-trained weights.
This model was trained using the MaskDistill-PyTorch codebase, reproducing the method from "A Unified View of Masked Image Modeling".
MaskDistill learns visual representations by distilling knowledge from a frozen CLIP ViT-B/16 teacher into a ViT-Base student through masked image modeling. The student learns to predict the teacher's features for masked patches using Smooth L1 loss.
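As a sketch, this per-patch distillation objective can be written with PyTorch's built-in Smooth L1 loss. The function name, mask layout, and feature shapes below are illustrative, not the exact training code:

```python
import torch
import torch.nn.functional as F

def distill_loss(student_feats, teacher_feats, mask):
    """Smooth L1 between student predictions and frozen-teacher features,
    computed over masked patch positions only (illustrative shapes: [B, N, D])."""
    pred = student_feats[mask]      # [num_masked, D] student predictions
    target = teacher_feats[mask]    # [num_masked, D] frozen CLIP targets
    return F.smooth_l1_loss(pred, target)

# Toy example: batch of 2 images, 196 patches (14x14), 768-dim features
B, N, D = 2, 196, 768
student = torch.randn(B, N, D)
teacher = torch.randn(B, N, D)
mask = torch.rand(B, N) < 0.4       # ~40% of patches masked
loss = distill_loss(student, teacher, mask)
```

In the real pipeline the targets come from the frozen CLIP ViT-B/16 and the gradient flows only through the student.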
| Evaluation | Result |
|---|---|
| Finetuning (ImageNet-1K) | 84.8% top-1 |
| k-NN (k=10) | 75.6% top-1 |
| Linear Probe | 76.3% top-1 |
| Sem. Seg. (ADE20K, UPerNet) | 52.6 mIoU |
| Obj. Det. (COCO, Mask R-CNN) | 44.4 bbox mAP |
| Inst. Seg. (COCO, Mask R-CNN) | 40.1 segm mAP |
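The k-NN number above follows the usual protocol of classifying each test feature by its k=10 nearest training features in the frozen embedding space. A minimal cosine-similarity, majority-vote version (the evaluation code in the repo may weight votes differently; names and shapes here are illustrative) looks like:

```python
import torch
import torch.nn.functional as F

def knn_predict(train_feats, train_labels, test_feats, k=10):
    """Classify each test feature by majority vote over its k nearest
    (cosine-similarity) training features. Features are [N, D] tensors."""
    train = F.normalize(train_feats, dim=1)
    test = F.normalize(test_feats, dim=1)
    sims = test @ train.T                      # [num_test, num_train]
    _, idx = sims.topk(k, dim=1)               # indices of k nearest neighbors
    votes = train_labels[idx]                  # [num_test, k] neighbor labels
    return torch.mode(votes, dim=1).values     # majority-vote prediction

# Toy check: two well-separated clusters should be classified correctly
train = torch.cat([torch.randn(20, 8) + 5.0, torch.randn(20, 8) - 5.0])
labels = torch.cat([torch.zeros(20, dtype=torch.long), torch.ones(20, dtype=torch.long)])
pred = knn_predict(train, labels, train, k=10)
```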
| File | Description |
|---|---|
| `pretrain_vit_base_ep290.pth` | Pretrained ViT-Base (300 epochs) |
| `finetune_vit_base_ep100.pth` | Finetuned on ImageNet-1K (84.8% top-1) |
| `linprobe_vit_base_ep90.pth.tar` | Linear probe (90 epochs, 76.3% top-1) |
| `semseg_upernet_ade20k_160k.pth` | UPerNet on ADE20K (52.6 mIoU) |
| `detection_maskrcnn_coco_12ep.pth` | Mask R-CNN on COCO (44.4 bbox mAP) |
```python
import torch

from src.models.vision_transformer import VisionTransformerMIM

# Build a ViT-Base/16 backbone matching the pretraining configuration
model = VisionTransformerMIM(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    use_shared_rel_pos_bias=True, use_mask_tokens=True,
)

# Load the pretrained checkpoint, keep only the student weights, and strip
# the DistributedDataParallel "module.student." prefix from the keys
ckpt = torch.load("pretrain_vit_base_ep290.pth", map_location="cpu")
state = {
    k.replace("module.student.", ""): v
    for k, v in ckpt["model"].items()
    if k.startswith("module.student.")
}
model.load_state_dict(state, strict=False)
```
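The key-remapping step above can be sanity-checked on a toy state dict; the checkpoint keys here are made up for illustration:

```python
# Hypothetical checkpoint keys as saved under DistributedDataParallel,
# containing both student and (to-be-discarded) teacher weights
ckpt_model = {
    "module.student.patch_embed.proj.weight": 0,
    "module.student.blocks.0.attn.qkv.weight": 0,
    "module.teacher.patch_embed.proj.weight": 0,  # teacher entry is dropped
}

# Keep student weights only and strip the wrapper prefix
state = {
    k.replace("module.student.", ""): v
    for k, v in ckpt_model.items()
    if k.startswith("module.student.")
}
```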
See the GitHub repo for full training and evaluation code.
```bibtex
@article{peng2022unified,
  title={A Unified View of Masked Image Modeling},
  author={Peng, Zhiliang and Dong, Li and Bao, Hangbo and Ye, Qixiang and Wei, Furu},
  journal={arXiv preprint arXiv:2210.10615},
  year={2022}
}
```