# How Do Vision Transformers Work?
[[paper](https://openreview.net/forum?id=D78Go4hVcxO), [arxiv](https://arxiv.org/abs/2202.06709), [poster](https://github.com/xxxnell/how-do-vits-work-storage/blob/master/resources/how_do_vits_work_poster_iclr2022.pdf), [slide](https://github.com/xxxnell/how-do-vits-work-storage/blob/master/resources/how_do_vits_work_talk.pdf)]
This repository provides a PyTorch implementation of ["How Do Vision Transformers Work? (ICLR 2022 Spotlight)"](https://openreview.net/forum?id=D78Go4hVcxO). In the paper, we show that the success of multi-head self-attentions (MSAs) for computer vision is ***NOT due to their weak inductive bias and capturing long-range dependency***.
In particular, we address the following three key questions of MSAs and Vision Transformers (ViTs):
***Q1. What properties of MSAs do we need to better optimize NNs?***
A1. MSAs have their pros and cons. MSAs improve NNs by flattening the loss landscapes. A key feature is their data specificity (data dependency), not long-range dependency. On the other hand, ViTs suffer from non-convex losses.
***Q2. Do MSAs act like Convs?***
A2. MSAs and Convs exhibit opposite behaviors; e.g., MSAs are low-pass filters, but Convs are high-pass filters. This suggests that MSAs are shape-biased, whereas Convs are texture-biased. Therefore, MSAs and Convs are complementary.
***Q3. How can we harmonize MSAs with Convs?***
A3. MSAs at the end of a stage (not a model) significantly improve the accuracy. Based on this, we introduce *AlterNet* by replacing Convs at the end of a stage with MSAs. AlterNet outperforms CNNs not only in large data regimes but also in small data regimes.
Let's find the detailed answers below!
### I. What Properties of MSAs Do We Need to Improve Optimization?
<p align="center">
<img src="resources/vit/loss-landscape.png" style="width:90%;">
</p>
MSAs improve not only accuracy but also generalization by flattening the loss landscapes. ***Such improvement is primarily attributable to their data specificity, NOT long-range dependency.*** On the other hand, ViTs suffer from non-convex losses. Their weak inductive bias and long-range dependency produce negative Hessian eigenvalues in small data regimes, and these non-convex points disrupt NN training. Large datasets and loss-landscape smoothing methods alleviate this problem.
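As a rough illustration of the Hessian analysis (this is a sketch, not code from this repository), the dominant Hessian eigenvalue of a loss can be estimated with power iteration on Hessian-vector products; a negative estimate signals the non-convex points discussed above. Here `loss` is assumed to be a scalar computed with a live autograd graph, e.g. `loss = criterion(model(x), y)`.

```python
import torch

def dominant_hessian_eigenvalue(loss, params, iters=20):
    """Estimate the largest-magnitude Hessian eigenvalue of `loss` w.r.t.
    `params` by power iteration on Hessian-vector products (a sketch; the
    paper reports full eigenvalue spectra, not a single eigenvalue)."""
    params = [p for p in params if p.requires_grad]
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]
    eig = None
    for _ in range(iters):
        # normalize the probe vector
        norm = torch.sqrt(sum((vi ** 2).sum() for vi in v))
        v = [vi / (norm + 1e-12) for vi in v]
        # Hessian-vector product: H v = d/dp (grad(loss) . v)
        gv = sum((g * vi).sum() for g, vi in zip(grads, v))
        hv = torch.autograd.grad(gv, params, retain_graph=True)
        # Rayleigh quotient v^T H v approximates the dominant eigenvalue
        eig = sum((h * vi).sum() for h, vi in zip(hv, v)).item()
        v = [h.detach() for h in hv]
    return eig
```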
### II. Do MSAs Act Like Convs?
<p align="center">
<img src="resources/vit/fourier.png" style="width:90%;">
</p>
MSAs and Convs exhibit opposite behaviors. Therefore, MSAs and Convs are complementary. For example, MSAs are low-pass filters, but Convs are high-pass filters. Likewise, Convs are vulnerable to high-frequency noise, whereas MSAs are vulnerable to low-frequency noise; this suggests that MSAs are shape-biased, whereas Convs are texture-biased. In addition, Convs transform feature maps and MSAs aggregate transformed feature map predictions. Thus, it is effective to place MSAs after Convs.
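The low-pass/high-pass distinction can be probed by comparing the Fourier amplitudes of feature maps, roughly in the spirit of `fourier_analysis.ipynb`. The sketch below is an approximation (the notebook averages over frequency bands) and reports only the log-amplitude difference between the highest and lowest frequencies of a feature map.

```python
import torch

def delta_log_amplitude(feature_map):
    """Log-amplitude difference between the highest and lowest frequency of a
    (B, C, H, W) feature map. Strongly negative values indicate low-pass
    behavior (as observed for MSAs); values closer to zero indicate more
    high-frequency content (as observed for Convs)."""
    f = torch.fft.fftshift(torch.fft.fft2(feature_map.float()), dim=(-2, -1))
    log_amp = torch.log(f.abs() + 1e-12)
    h, w = log_amp.shape[-2:]
    low = log_amp[..., h // 2, w // 2]   # DC (lowest-frequency) component
    high = log_amp[..., 0, 0]            # a highest-frequency corner after fftshift
    return (high - low).mean().item()
```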
### III. How Can We Harmonize MSAs With Convs?
<p align="center">
<img src="resources/vit/architecture.png" style="width:90%;">
</p>
Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage (not the end of a model) play a key role in prediction. Based on these insights, we propose design rules to harmonize MSAs with Convs. NN stages following this design pattern consist of a number of CNN blocks and one (or a few) MSA blocks. The design pattern naturally derives the structure of the canonical Transformer, which has one MLP block for each MSA block.
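As a schematic sketch of this design rule (illustrative only; see `models/alternet.py` for the actual AlterNet blocks), a stage can be built from several Conv blocks followed by a single self-attention block. `ToyMSABlock`, `make_stage`, and the `conv_block` constructor are placeholders, not names from this repository.

```python
import torch.nn as nn

class ToyMSABlock(nn.Module):
    """Minimal pre-norm self-attention block over spatial tokens (illustrative)."""
    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):                       # x: (B, C, H, W)
        b, c, h, w = x.shape
        t = x.flatten(2).transpose(1, 2)        # (B, HW, C) tokens
        n = self.norm(t)
        t = t + self.attn(n, n, n)[0]           # residual self-attention
        return t.transpose(1, 2).reshape(b, c, h, w)

def make_stage(conv_block, dim, depth):
    """Design rule: Conv blocks throughout the stage, one MSA block at its end."""
    blocks = [conv_block(dim) for _ in range(depth - 1)]
    blocks.append(ToyMSABlock(dim))
    return nn.Sequential(*blocks)
```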
<br />
<p align="center">
<img src="resources/vit/alternet.png" style="width:90%;">
</p>
Based on these design rules, we introduce AlterNet ([code](https://github.com/xxxnell/how-do-vits-work/blob/transformer/models/alternet.py)) by replacing Conv blocks at the end of a stage with MSA blocks. ***Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes***, e.g., CIFAR. This contrasts with canonical ViTs, which perform poorly on small amounts of data. For more details, see the ["How to Apply MSA to Your Own Model"](#how-to-apply-msa-to-your-own-model) section below.
This repository is based on [the official implementation of "Blurs Behave Like Ensembles: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness"](https://github.com/xxxnell/spatial-smoothing). In that paper, we show that a simple (non-trainable) 2 × 2 box blur filter improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs. MSA is not simply a generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!
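For reference, such a non-trainable 2 × 2 box blur can be written as a tiny PyTorch module. This is a sketch of the idea rather than the exact spatial-smoothing implementation.

```python
import torch.nn as nn

# 2 x 2 box blur over feature maps: replication padding keeps the spatial
# size, and stride-1 average pooling performs the uniform 2 x 2 averaging.
box_blur = nn.Sequential(
    nn.ReplicationPad2d((0, 1, 0, 1)),
    nn.AvgPool2d(kernel_size=2, stride=1),
)
```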
## Getting Started
The following packages are required:
* pytorch
* matplotlib
* notebook
* ipywidgets
* timm
* einops
* tensorboard
* seaborn (optional)
We mainly use the Docker image `pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime` for the code.
See [```classification.ipynb```](classification.ipynb) ([Colab notebook](https://colab.research.google.com/github/xxxnell/how-do-vits-work/blob/transformer/classification.ipynb)) for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.
**Metrics.** We provide several metrics for measuring accuracy and uncertainty: Accuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↓), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.
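For instance, ECE can be computed with a standard binned formulation as sketched below (the repository's implementation may differ in details). Here `confidences` are per-sample maximum softmax probabilities and `correctness` is a boolean tensor marking whether each prediction is correct.

```python
import torch

def expected_calibration_error(confidences, correctness, n_bins=15):
    """ECE: |accuracy - confidence| per confidence bin, weighted by the
    fraction of samples falling into that bin."""
    bins = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(())
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            acc = correctness[mask].float().mean()
            conf = confidences[mask].mean()
            ece = ece + mask.float().mean() * (acc - conf).abs()
    return ece.item()
```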
**Models.** We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default. timm implementations can also be used.
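For ImageNet-1K, pretrained models can be created with the standard timm API, for example:

```python
import timm

# Pretrained ImageNet-1K models via the standard timm API
vit = timm.create_model("vit_tiny_patch16_224", pretrained=True)
resnet = timm.create_model("resnet50", pretrained=True)
```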
<details>
<summary>
Four pretrained models for CIFAR-100 are also provided: <a href="https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar">ResNet-50</a>, <a href="https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar">ViT-Ti</a>, <a href="https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/pit_ti_cifar100_0645889efb.pth.tar">PiT-Ti</a>, and <a href="https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/swin_ti_cifar100_ec2894492b.pth.tar">Swin-Ti</a>. We recommend using <a href="https://github.com/rwightman/pytorch-image-models">timm</a> for ImageNet-1K (e.g., please refer to <code><a href="https://github.com/xxxnell/how-do-vits-work/blob/transformer/fourier_analysis.ipynb">fourier_analysis.ipynb</a></code>).
</summary>
<br/>
The snippets below show how to (a) load a pretrained model and (b) convert it into a sequence of blocks.
<br/>
```python
# ResNet-50
import torch

import models

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar"
path = "checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar"
models.download(url=url, path=path)

name = "resnet_50"
num_classes = 100  # CIFAR-100
model = models.get_model(name, num_classes=num_classes,
                         stem=False)  # timm does not provide a ResNet for CIFAR
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])

# b. model → blocks. `blocks` is a sequence of blocks
blocks = [
    model.layer0,
    *model.layer1,
    *model.layer2,
    *model.layer3,
    *model.layer4,
    model.classifier,
]
```
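With `blocks` in hand, intermediate feature maps can be collected block by block, e.g. for the Fourier analysis sketched above. This usage sketch continues from the snippet above and feeds a dummy CIFAR-sized input through every block except the classifier head.

```python
import torch

model.eval()
xs = torch.randn(2, 3, 32, 32)      # dummy CIFAR-sized batch
feature_maps = []
with torch.no_grad():
    for block in blocks[:-1]:       # skip the classifier head
        xs = block(xs)
        feature_maps.append(xs)
```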
```python
# ViT-Ti
import