#!/usr/bin/env python3
# encoding: utf-8
"""
@Time : 2021/7/7 19:52
@Author : Xie Cheng
@File : transformer.py
@Software: PyCharm
@desc: Transformer architecture https://zhuanlan.zhihu.com/p/370481790
"""
import math
import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Inject sinusoidal position information into a (seq_len, batch, d_model) tensor."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the sinusoidal table once:
        # pe[pos, 2i] = sin(pos / 10000^(2i/d_model)), pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model), broadcastable over the batch dim
        # A buffer moves with the module (.to()/.cuda()) but is not a learnable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model). Index the table by sequence length (dim 0),
        # not batch size; the singleton dim 1 broadcasts across the batch.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
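
# A minimal sanity check for PositionalEncoding (illustrative sketch; the shapes
# below are arbitrary and not part of the original training pipeline):
#
#     pe = PositionalEncoding(d_model=512, dropout=0.0)
#     x = torch.zeros(50, 4, 512)            # (seq_len, batch, d_model)
#     assert pe(x).shape == (50, 4, 512)     # shape is preserved; sinusoids vary along dim 0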


class TransformerTS(nn.Module):
    def __init__(self,
                 input_dim,
                 dec_seq_len,
                 out_seq_len,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='relu',
                 custom_encoder=None,
                 custom_decoder=None):
r"""A transformer model. User is able to modify the attributes as needed. The architecture
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
model with corresponding parameters.
Args:
input_dim: dimision of imput series
d_model: the number of expected features in the encoder/decoder inputs (default=512).
nhead: the number of heads in the multiheadattention models (default=8).
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
custom_encoder: custom encoder (default=None).
custom_decoder: custom decoder (default=None).
Examples::
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = torch.rand((10, 32, 512)) (time length, N, feature dim)
>>> tgt = torch.rand((20, 32, 512))
>>> out = transformer_model(src, tgt)
Note: A full example to apply nn.Transformer module for the word language model is available in
https://github.com/pytorch/examples/tree/master/word_language_model
"""
        super(TransformerTS, self).__init__()
        self.transform = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            custom_encoder=custom_encoder,
            custom_decoder=custom_decoder
        )
        self.pos = PositionalEncoding(d_model)
        # Project the raw series (input_dim features per step) into the model dimension.
        self.enc_input_fc = nn.Linear(input_dim, d_model)
        self.dec_input_fc = nn.Linear(input_dim, d_model)
        # Flatten the decoder output over time and map it to the forecast horizon.
        self.out_fc = nn.Linear(dec_seq_len * d_model, out_seq_len)
        self.dec_seq_len = dec_seq_len
    def forward(self, x):
        # x: (batch, seq_len, input_dim) -> (seq_len, batch, input_dim) for nn.Transformer
        x = x.transpose(0, 1)
        # embedding: project to d_model and add positional encoding (encoder side only)
        embed_encoder_input = self.pos(self.enc_input_fc(x))
        # the decoder is fed the last dec_seq_len steps of the input window
        embed_decoder_input = self.dec_input_fc(x[-self.dec_seq_len:, :])
        # transform
        x = self.transform(embed_encoder_input, embed_decoder_input)
        # output: (dec_seq_len, batch, d_model) -> (batch, dec_seq_len * d_model) -> (batch, out_seq_len)
        x = x.transpose(0, 1)
        x = self.out_fc(x.flatten(start_dim=1))
        return x.squeeze()
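

if __name__ == '__main__':
    # Illustrative usage sketch, not from the original repository: the window
    # lengths and feature count below are made-up values for a quick smoke test.
    model = TransformerTS(input_dim=1, dec_seq_len=4, out_seq_len=2)
    batch = torch.rand(8, 16, 1)   # (batch, seq_len, input_dim)
    pred = model(batch)            # forecast for each series in the batch
    print(pred.shape)              # expected: torch.Size([8, 2])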