# vit-keras
This is a Keras implementation of the models described in [An Image is Worth 16x16 Words:
Transformers For Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf). It is based on an earlier implementation from [tuvovan](https://github.com/tuvovan/Vision_Transformer_Keras), modified to match the Flax implementation in the [official repository](https://github.com/google-research/vision_transformer).
The weights here are ported over from the weights provided in the official repository. See `utils.load_weights_numpy` to see how this is done (it's not pretty, but it does the job).
## Usage
Install this package using `pip install vit-keras`.
You can use the model out-of-the-box with ImageNet 2012 classes using
something like the following. The weights will be downloaded automatically.
```python
from vit_keras import vit, utils

image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)

# Download an example image and resize it to the model's input size.
url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Granny_smith_and_cross_section.jpg'
image = utils.read(url, image_size)

# Preprocess and add a batch dimension before predicting.
X = vit.preprocess_inputs(image).reshape(1, image_size, image_size, 3)
y = model.predict(X)
print(classes[y[0].argmax()])  # Granny smith
```
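The example above prints only the single best class. If you want to see the runners-up as well, a small helper along these lines may be useful. This is a minimal sketch: `top_k_predictions` is a hypothetical name that is not part of `vit_keras`, and the toy labels and scores stand in for `classes` and `y[0]`.

```python
import numpy as np


def top_k_predictions(scores, class_names, k=5):
    """Return the k highest-scoring (class_name, score) pairs, best first.

    Hypothetical helper, not part of vit_keras.
    """
    scores = np.asarray(scores, dtype=float)
    # Sorting the negated scores ascending yields indices in descending score order.
    top = np.argsort(-scores)[:k]
    return [(class_names[i], float(scores[i])) for i in top]


# Toy stand-ins for `classes` and `y[0]` from the example above.
demo_classes = ['tench', 'goldfish', 'Granny Smith', 'husky']
demo_scores = np.array([0.02, 0.10, 0.91, 0.35])
print(top_k_predictions(demo_scores, demo_classes, k=2))
# [('Granny Smith', 0.91), ('husky', 0.35)]
```

With the real model you would call it as `top_k_predictions(y[0], classes)`.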
You can fine-tune using a model loaded as follows.
```python
from vit_keras import vit

image_size = 224
model = vit.vit_l32(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=False,  # use a freshly initialized classification head
    classes=200
)
# Train this model on your data as desired.
```
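When picking an `image_size` for fine-tuning, it can help to know how many tokens the transformer will see, since that grows with the square of the image side. The arithmetic below is a small sketch, assuming the patch sizes implied by the model names (16 for `vit_b16`, 32 for `vit_l32`) and one extra class token as described in the paper. `vit_sequence_length` is a hypothetical helper, not part of `vit_keras`.

```python
def vit_sequence_length(image_size, patch_size):
    """Number of tokens the encoder sees: one per patch plus the class token.

    Hypothetical helper for illustration only.
    """
    if image_size % patch_size != 0:
        raise ValueError('image_size must be a multiple of patch_size')
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


print(vit_sequence_length(384, 16))  # 577 (24 x 24 patches + class token)
print(vit_sequence_length(224, 32))  # 50  (7 x 7 patches + class token)
```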
## Visualizing Attention Maps
There's some functionality for plotting attention maps for a given image and model; see the example below. I'm not sure I'm doing this correctly (the official repository didn't have example code). Feedback and corrections are welcome!
```python
import numpy as np
import matplotlib.pyplot as plt
from vit_keras import vit, utils, visualize

# Load a model
image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)

# Get an image and compute the attention map
url = 'https://upload.wikimedia.org/wikipedia/commons/b/bc/Free%21_%283987584939%29.jpg'
image = utils.read(url, image_size)
attention_map = visualize.attention_map(model=model, image=image)
print('Prediction:', classes[
    model.predict(vit.preprocess_inputs(image)[np.newaxis])[0].argmax()]
)  # Prediction: Eskimo dog, husky

# Plot results
fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(image)
_ = ax2.imshow(attention_map)
```
[Figure: example of attention map]
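One common way to turn per-layer attention weights into a single map is attention rollout: average the heads, add the identity to account for residual connections, renormalize, and multiply the layers together. The sketch below shows that technique on random weights. It is an illustration only, not a copy of `visualize.attention_map`, and this README does not state that the library follows exactly these steps. The function name, argument layout, and toy shapes are all assumptions.

```python
import numpy as np


def attention_rollout(attentions, grid_size):
    """Combine per-layer attention into one (grid_size, grid_size) mask.

    attentions: list of arrays shaped (heads, tokens, tokens), first layer
    first, where tokens == 1 + grid_size ** 2 and index 0 is the class token.
    Hypothetical helper for illustration only.
    """
    num_tokens = attentions[0].shape[-1]
    rollout = np.eye(num_tokens)
    for layer_attention in attentions:
        # Average the attention weights over heads.
        fused = layer_attention.mean(axis=0)
        # Add the identity to model the residual connection, then renormalize rows.
        fused = fused + np.eye(num_tokens)
        fused = fused / fused.sum(axis=-1, keepdims=True)
        # Accumulate attention flow from the input up to this layer.
        rollout = fused @ rollout
    # Row 0 is the class token; drop its weight on itself and keep the patches.
    mask = rollout[0, 1:].reshape(grid_size, grid_size)
    return mask / mask.max()


# Toy input: 3 layers, 4 heads, 1 class token + 2 x 2 patches.
rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 4, 5, 5))
weights = np.exp(logits)
weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the last axis
mask = attention_rollout(list(weights), grid_size=2)
print(mask.shape)  # (2, 2)
```

To overlay such a mask on an image, it would still need to be resized from the patch grid to the image resolution.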