import torch
import torch.nn as nn
from einops import rearrange
from unet import TransformerEncoder
from ..common import expand_to_batch
class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Minimal re-implementation of ViT
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token that we will index in the final classification layer
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'img dim {img_dim} must be divisible by patch dim {patch_dim}'
        self.p = patch_dim
        self.classification = classification
        # tokens = number of patches
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head

        # Projection and pos embeddings
        self.project_patches = nn.Linear(self.token_dim, dim)
        self.emb_dropout = nn.Dropout(dropout)

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)

        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer
    def forward(self, img, mask=None):
        # Create patches
        # from [batch, channels, h, w] to [batch, tokens, N], N = p*p*c, tokens = (h/p) * (w/p)
        img_patches = rearrange(img,
                                'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)

        batch_size, tokens, _ = img_patches.shape

        # project patches with linear layer
        img_patches = self.project_patches(img_patches)

        # prepend the CLS token to the patch sequence
        img_patches = torch.cat((expand_to_batch(self.cls_token, desired_size=batch_size), img_patches), dim=1)

        # add pos. embeddings + dropout
        # indexing with the current batch's token length to support variable sequences
        img_patches = img_patches + self.pos_emb1D[:tokens + 1, :]
        patch_embeddings = self.emb_dropout(img_patches)

        # feed patch_embeddings to the transformer. output shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :]) if self.classification else y[:, 1:, :]
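

# A minimal usage sketch (not part of the original module): it assumes square RGB inputs,
# an img_dim divisible by patch_dim, and that this file is importable as part of the
# package layout above (the relative import of expand_to_batch requires running it
# through the package, e.g. `python -m vit.vit` from the project root).
if __name__ == '__main__':
    dummy_img = torch.randn(2, 3, 64, 64)   # [batch, channels, h, w]

    # Classification mode: the CLS token is fed to the mlp_head.
    model = ViT(img_dim=64, in_channels=3, patch_dim=16, num_classes=10, dim=512)
    logits = model(dummy_img)
    print(logits.shape)      # expected: torch.Size([2, 10])

    # Feature mode: returns the per-patch embeddings without the CLS token.
    feats = ViT(img_dim=64, in_channels=3, patch_dim=16, dim=512, classification=False)(dummy_img)
    print(feats.shape)       # expected: torch.Size([2, 16, 512]), 16 = (64 / 16) ** 2 patches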