import math
import sys
from typing import Optional, Any

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
class CNNBlock(nn.Module):
    """Convolutional feature extractor shared by the models below.

    The shape comments trace the single-window case [B, 1, 6, 128] (6 sensor
    channels, 128 time steps); the time-concatenated inputs used by some callers
    below ([B, 1, 6, 256]) simply double the width of every feature map. With
    vertical=True the block expects two windows stacked along the sensor axis,
    i.e. [B, 1, 12, 128]. The `tcn` flag is accepted for compatibility with the
    callers below but is currently unused.
    """

    def __init__(self, vertical=False, tcn=False) -> None:
        super().__init__()
        if vertical:
            # Input [B, 1, 12, 128]
            self.block = nn.Sequential(
                nn.BatchNorm2d(1),
                nn.Conv2d(1, 32, (1, 9), stride=(1, 2), padding=(0, 4)),  # [B,32,12,64]
                nn.ReLU(),
                nn.MaxPool2d((1, 2), (1, 2)),                             # [B,32,12,32]
                nn.Conv2d(32, 64, (1, 3), padding=(0, 1)),                # [B,64,12,32]
                nn.ReLU(),
                nn.Conv2d(64, 128, (1, 3), padding=(0, 1)),               # [B,128,12,32]
                nn.ReLU(),
                nn.MaxPool2d((1, 2), (1, 2)),                             # [B,128,12,16]
                nn.Conv2d(128, 128, (11, 1)),                             # [B,128,2,16]
                nn.ReLU()
            )
        else:
            # Input [B, 1, 6, 128]
            self.block = nn.Sequential(
                nn.BatchNorm2d(1),
                nn.Conv2d(1, 32, (1, 9), stride=(1, 2), padding=(0, 4)),  # [B,32,6,64]
                nn.ReLU(),
                nn.MaxPool2d((1, 2), (1, 2)),                             # [B,32,6,32]
                nn.Conv2d(32, 64, (1, 3), padding=(0, 1)),                # [B,64,6,32]
                nn.ReLU(),
                nn.Conv2d(64, 128, (1, 3), padding=(0, 1)),               # [B,128,6,32]
                nn.ReLU(),
                nn.MaxPool2d((1, 2), (1, 2)),                             # [B,128,6,16]
                nn.Conv2d(128, 128, (6, 1)),                              # [B,128,1,16]
                nn.ReLU()
            )

    def forward(self, x):
        return self.block(x)
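
# A minimal shape-check sketch (an assumption added for illustration, not part of the
# original training code): it traces the shape comments above for both layouts.
def _check_cnnblock_shapes():
    x = torch.randn(2, 1, 6, 128)      # one 6-channel, 128-step window
    assert CNNBlock().eval()(x).shape == (2, 128, 1, 16)
    xv = torch.randn(2, 1, 12, 128)    # two windows stacked along the sensor axis (vertical layout)
    assert CNNBlock(vertical=True).eval()(xv).shape == (2, 128, 2, 16)
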
########################################################################################################
class SiaCNNAuth(nn.Module):
    """Siamese CNN: each branch maps a [B, 6, 128] window to a 32-dim embedding."""

    def __init__(self) -> None:
        super().__init__()
        self.cnn_block = CNNBlock(tcn=True)
        self.linear1 = nn.Linear(2048, 512)
        self.linear2 = nn.Linear(512, 32)

    def forward_once(self, x):
        x = x.reshape(-1, 1, 6, 128)
        x = self.cnn_block(x)
        x = torch.flatten(x, start_dim=1)   # [B, 2048]
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

    def forward(self, x1, x2):
        output1, output2 = self.forward_once(x1), self.forward_once(x2)
        # output = abs(self.forward_once(x1) - self.forward_once(x2))
        return output1, output2
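
# A minimal usage sketch (illustrative assumption, not the original pipeline): the two
# branches return one 32-dim embedding per window, and a distance between them can then
# drive a contrastive-style same-user decision. The threshold is purely illustrative.
def _sia_cnn_demo():
    model = SiaCNNAuth().eval()
    x1, x2 = torch.randn(4, 6, 128), torch.randn(4, 6, 128)
    with torch.no_grad():
        e1, e2 = model(x1, x2)              # each [4, 32]
    dist = F.pairwise_distance(e1, e2)      # [4]
    return dist < 1.0                       # illustrative threshold only
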
class OriginCNNAuth(nn.Module):
    """CNN over the two windows merged into a single image, emitting 2-class logits."""

    def __init__(self, vertical=False) -> None:
        super().__init__()
        self.vertical = vertical
        self.cnn_block = CNNBlock(vertical, tcn=True)
        self.linear1 = nn.Linear(2048 * 2, 2)

    def forward(self, x1, x2):
        x1 = x1.reshape(-1, 1, 6, 128)
        x2 = x2.reshape(-1, 1, 6, 128)
        if self.vertical:
            x = torch.cat([x1, x2], dim=2)   # stack sensor channels: [B, 1, 12, 128]
        else:
            x = torch.cat([x1, x2], dim=3)   # concatenate in time:   [B, 1, 6, 256]
        output = self.cnn_block(x)
        x = torch.flatten(output, start_dim=1)   # [B, 4096] in both layouts
        x = self.linear1(x)
        return x
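
# A minimal usage sketch (illustrative assumption): the pair is merged into a single
# image before the CNN, either stacked along the sensor axis (vertical=True, height 12)
# or concatenated in time (vertical=False, width 256), and two-class logits come out.
def _origin_cnn_demo(vertical=False):
    model = OriginCNNAuth(vertical=vertical).eval()
    x1, x2 = torch.randn(4, 6, 128), torch.randn(4, 6, 128)
    with torch.no_grad():
        logits = model(x1, x2)              # [4, 2]
    return logits.argmax(dim=1)             # 0/1 same-user prediction (illustrative)
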
#########################################################################################################
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise ValueError("activation should be relu/gelu, not {}".format(activation))
class SiaCNNLSTMAuth(nn.Module):
    """Siamese branch fusing a CNN summary with the last LSTM state of the raw sequence."""

    def __init__(self, input_size=6, hidden_size=1024) -> None:
        super().__init__()
        self.cnn_block = CNNBlock()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, num_layers=2)
        self.linear = nn.Linear(hidden_size + 2048, 32)

    def forward_once(self, x):
        y = x.permute(0, 2, 1)                 # [B, 128, 6] for the LSTM
        x = x.reshape(-1, 1, 6, 128)
        x = self.cnn_block(x)
        x = torch.flatten(x, start_dim=1)      # [B, 2048]
        z, _ = self.lstm(y)
        z = z[:, -1, :]                        # [B, hidden_size]
        x = torch.cat([x, z], dim=1)
        x = self.linear(x)
        return x

    def forward(self, x1, x2):
        return self.forward_once(x1), self.forward_once(x2)
class SiaCNNLSTMAuthSer(nn.Module):
    """Siamese branch running the LSTM serially on top of the CNN feature maps."""

    def __init__(self, input_size=128, hidden_size=1024) -> None:
        super().__init__()
        self.cnn_block = CNNBlock()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, num_layers=2)
        self.linear = nn.Linear(hidden_size, 32)

    def forward_once(self, x):
        x = x.reshape(-1, 1, 6, 128)
        x = self.cnn_block(x)                  # [B, 128, 1, 16]
        x = x.permute(0, 3, 1, 2).squeeze(-1)  # [B, 16, 128]: length-16 sequence of 128-dim features
        z, _ = self.lstm(x)
        z = z[:, -1, :]                        # [B, hidden_size]
        x = self.linear(z)
        return x

    def forward(self, x1, x2):
        return self.forward_once(x1), self.forward_once(x2)
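
# A minimal shape sketch (illustrative assumption) covering both Siamese CNN-LSTM variants
# above: each takes two [B, 6, 128] windows and returns two 32-dim embeddings, differing
# only in whether the LSTM runs on the raw sequence or on the CNN feature maps.
def _sia_cnn_lstm_demo():
    x1, x2 = torch.randn(4, 6, 128), torch.randn(4, 6, 128)
    for model in (SiaCNNLSTMAuth().eval(), SiaCNNLSTMAuthSer().eval()):
        with torch.no_grad():
            e1, e2 = model(x1, x2)
        assert e1.shape == (4, 32) and e2.shape == (4, 32)
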
class OriginCNNLSTMAuth(nn.Module):
    """CNN per window, then one LSTM over the two feature sequences, emitting 2-class logits."""

    def __init__(self, vertical=False, hidden_size=1024) -> None:
        super().__init__()
        self.vertical = vertical
        self.cnn_block = CNNBlock()
        if vertical:
            input_size = 256   # features of both windows concatenated per time step
        else:
            input_size = 128   # the two length-16 feature sequences concatenated in time
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, num_layers=2)
        self.linear = nn.Linear(hidden_size, 256)
        self.linear2 = nn.Linear(256, 2)

    def forward_once(self, x):
        x = x.reshape(-1, 1, 6, 128)
        x = self.cnn_block(x)
        return x

    def forward(self, x1, x2):
        x1 = self.forward_once(x1).permute(0, 3, 1, 2).squeeze(-1)   # [B, 16, 128]
        x2 = self.forward_once(x2).permute(0, 3, 1, 2).squeeze(-1)   # [B, 16, 128]
        if self.vertical:
            x = torch.cat([x1, x2], dim=2)   # [B, 16, 256]
        else:
            x = torch.cat([x1, x2], dim=1)   # [B, 32, 128]
        z, _ = self.lstm(x)
        z = z[:, -1, :]
        x = torch.sigmoid(self.linear(z))
        return self.linear2(x)
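
# A minimal usage sketch (illustrative assumption): the per-window CNN maps become
# length-16 feature sequences; vertical=True concatenates them along the feature axis
# (sequence length 16, 256 features), otherwise along time (length 32, 128 features).
def _origin_cnn_lstm_demo(vertical=False):
    model = OriginCNNLSTMAuth(vertical=vertical).eval()
    x1, x2 = torch.randn(4, 6, 128), torch.randn(4, 6, 128)
    with torch.no_grad():
        logits = model(x1, x2)              # [4, 2]
    return logits
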
class CNNLSTMAuthPara(nn.Module):
    """Parallel CNN and LSTM over the merged pair of windows, fused before the classifier."""

    def __init__(self, vertical=False, hidden_size=1024) -> None:
        super().__init__()
        self.vertical = vertical
        self.cnn_block = CNNBlock(vertical)
        if vertical:
            input_size = 12   # sensor channels of both windows stacked
        else:
            input_size = 6
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, num_layers=2)
        self.linear = nn.Linear(hidden_size + 2048 * 2, 512)
        self.linear2 = nn.Linear(512, 2)

    def forward(self, x1, x2):
        x1 = x1.reshape(-1, 1, 6, 128)
        x2 = x2.reshape(-1, 1, 6, 128)
        if self.vertical:
            x = torch.cat([x1, x2], dim=2)   # [B, 1, 12, 128]
        else:
            x = torch.cat([x1, x2], dim=3)   # [B, 1, 6, 256]
        output1 = self.cnn_block(x)
        output1 = torch.flatten(output1, start_dim=1)              # [B, 4096]
        output2, _ = self.lstm(x.permute(0, 3, 2, 1).squeeze(-1))  # LSTM over the raw time axis
        output2 = F.relu(output2[:, -1, :])
        x = torch.cat([output1, output2], dim=1)
        x = F.relu(self.linear(x))
        return self.linear2(x)
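
# A minimal usage sketch (illustrative assumption): here the CNN and the LSTM see the same
# merged pair in parallel, and their 4096-dim and 1024-dim summaries are fused before the
# two-class head.
def _cnn_lstm_para_demo(vertical=False):
    model = CNNLSTMAuthPara(vertical=vertical).eval()
    x1, x2 = torch.randn(4, 6, 128), torch.randn(4, 6, 128)
    with torch.no_grad():
        logits = model(x1, x2)              # [4, 2]
    return logits
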
########################################################################################################################
# From https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class FixedPositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}})
        \text{where pos is the word position and i is the embed idx}

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=1024).
    """
    def __init__(self, d_model, dropout=0.1, max_len=1024, scale_factor=1.0):
        super(FixedPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # positional encoding table
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)