import math

import torch
import torch.nn as nn
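
# Encoder-decoder Generator and a convolutional Discriminator for a GAN.
# The shape comments below assume 256x192 RGB inputs; the Discriminator's
# 4-channel input is presumably the RGB image concatenated with the
# Generator's 1-channel output map.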

def conv2d(in_channels, out_channels, kernel_size=3, padding=1):
    """3x3 convolution with padding 1 by default (preserves spatial size)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)


def deconv2d(in_channels, out_channels, kernel_size=3, padding=1):
    """3x3 transposed convolution with stride 1 (preserves spatial size)."""
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)


def relu(inplace=True):
    return nn.ReLU(inplace=inplace)


def maxpool2d():
    """2x2 max pooling, halving the spatial dimensions."""
    return nn.MaxPool2d(2)

def make_conv_layers(cfg):
    """Stack conv/ReLU blocks from a VGG-style config; 'M' inserts a 2x2 max pool."""
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [maxpool2d()]
        else:
            layers += [conv2d(in_channels, v), relu(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

def make_deconv_layers(cfg):
    """Stack transposed-conv blocks; 'U' inserts a 2x nearest-neighbour upsample."""
    layers = []
    in_channels = 512
    for v in cfg:
        if v == 'U':
            layers += [nn.Upsample(scale_factor=2)]
        else:
            layers += [deconv2d(in_channels, v)]
            in_channels = v
    return nn.Sequential(*layers)
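
# 'E' is the VGG-16 feature configuration without its final max pool; 'D'
# mirrors it, trading each 'M' (2x downsample) for a 'U' (2x upsample).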
cfg = {
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    'D': [512, 512, 512, 'U', 512, 512, 512, 'U', 256, 256, 256, 'U', 128, 128, 'U', 64, 64],
}

def encoder():
    return make_conv_layers(cfg['E'])


def decoder():
    return make_deconv_layers(cfg['D'])

class Generator(nn.Module):
    """Encoder-decoder network mapping an RGB image to a 1-channel map in [0, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = encoder()
        self.decoder = decoder()
        self.mymodules = nn.ModuleList([
            deconv2d(64, 1, kernel_size=1, padding=0),  # 1x1 projection to one channel
            nn.Sigmoid(),
        ])

    def forward(self, x):
        x = self.encoder(x)       # [N, 512, H/16, W/16]
        x = self.decoder(x)       # [N, 64, H, W]
        x = self.mymodules[0](x)  # [N, 1, H, W]
        x = self.mymodules[1](x)  # squash to [0, 1]
        return x

class Discriminator(nn.Module):
    """Convolutional classifier over 4-channel inputs (image + generated map)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.convs = nn.Sequential(                  # input: [-1, 4, 256, 192]
            conv2d(4, 3, kernel_size=1, padding=0),  # padding=0 so the 1x1 conv preserves size
            relu(),
            conv2d(3, 32, kernel_size=3),            # [-1, 32, 256, 192]
            relu(),
            maxpool2d(),
            conv2d(32, 64, kernel_size=3),           # [-1, 64, 128, 96]
            relu(),
            conv2d(64, 64, kernel_size=3),           # [-1, 64, 128, 96]
            relu(),
            maxpool2d(),                             # [-1, 64, 64, 48]
            conv2d(64, 64, kernel_size=3),           # [-1, 64, 64, 48]
            relu(),
            conv2d(64, 64, kernel_size=3),
            relu(),
            maxpool2d(),                             # [-1, 64, 32, 24]
        )
        self.mymodules = nn.ModuleList([
            nn.Sequential(nn.Linear(64 * 32 * 24, 100), nn.Tanh()),
            nn.Sequential(nn.Linear(100, 2), nn.Tanh()),
            nn.Sequential(nn.Linear(2, 1), nn.Sigmoid()),
        ])
        # self._initialize_weights()

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self.num_flat_features(x))  # flatten everything but the batch dim
        x = self.mymodules[0](x)
        x = self.mymodules[1](x)
        x = self.mymodules[2](x)
        return x

    def num_flat_features(self, x):
        # All dimensions except the batch dimension.
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def _initialize_weights(self):
        # He (Kaiming) initialization for conv layers; small Gaussian for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
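

if __name__ == "__main__":
    # Minimal shape check (a sketch): the Discriminator's shape comments assume
    # 256x192 inputs, so the same size is assumed here. Concatenating the image
    # with the Generator output yields the 4-channel Discriminator input.
    g = Generator()
    d = Discriminator()
    img = torch.randn(1, 3, 256, 192)
    out_map = g(img)
    print('Generator output:', out_map.size())    # expected: [1, 1, 256, 192]
    score = d(torch.cat([img, out_map], dim=1))
    print('Discriminator output:', score.size())  # expected: [1, 1]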