# Day 3 作业--Pixel2Pixel:人像卡通化
经过今天的学习,相信大家对图像翻译、风格迁移有了一定的了解啦,是不是也想自己动手来实现下呢?
那么,为了满足大家动手实践的愿望,同时为了巩固大家学到的知识,我们Day 3的作业便是带大家完成一遍课程讲解过的应用--**Pixel2Pixel:人像卡通化**
在本次作业中,大家需要做的是:**补齐代码,跑通训练,提交一张卡通化的成品图,动手完成自己的第一个人像卡通化的应用~**
![](https://www.writebug.com/myres/static/uploads/2022/3/3/7b7cc87a136548f049e10c09e6febc61.writebug)
## 准备工作:引入依赖 & 数据准备
```python
import paddle
import paddle.nn as nn
from paddle.io import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
```
```python
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Sized
```
### 数据准备:
- 真人数据来自[seeprettyface](http://www.seeprettyface.com/mydataset.html)。
- 数据预处理(详情见[photo2cartoon](https://github.com/minivision-ai/photo2cartoon)项目)。
![](https://www.writebug.com/myres/static/uploads/2022/3/3/6e13b60bf623de915004b7ec31150317.writebug)
- 使用[photo2cartoon](https://github.com/minivision-ai/photo2cartoon)项目生成真人数据对应的卡通数据。
```python
# 解压数据
!unzip -qao data/data79149/cartoon_A2B.zip -d data/
```
### 数据可视化
```python
# 训练数据统计
train_names = os.listdir('data/cartoon_A2B/train')
print(f'训练集数据量: {len(train_names)}')
# 测试数据统计
test_names = os.listdir('data/cartoon_A2B/test')
print(f'测试集数据量: {len(test_names)}')
# 训练数据可视化
imgs = []
for img_name in np.random.choice(train_names, 3, replace=False):
imgs.append(cv2.imread('data/cartoon_A2B/train/'+img_name))
img_show = np.vstack(imgs)[:,:,::-1]
plt.figure(figsize=(10, 10))
plt.imshow(img_show)
plt.show()
```
训练集数据量: 1361
测试集数据量: 100
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return list(data) if isinstance(data, collections.MappingView) else data
![](https://www.writebug.com/myres/static/uploads/2022/3/3/b8af7e301043a7a612af02a7fbf47c1d.writebug)
```python
class PairedData(Dataset):
def __init__(self, phase):
super(PairedData, self).__init__()
self.img_path_list = self.load_A2B_data(phase) # 获取数据列表
self.num_samples = len(self.img_path_list) # 数据量
def __getitem__(self, idx):
img_A2B = cv2.imread(self.img_path_list[idx]) # 读取一组数据
img_A2B = img_A2B.astype('float32') / 127.5 - 1. # 从0~255归一化至-1~1
img_A2B = img_A2B.transpose(2, 0, 1) # 维度变换HWC -> CHW
img_A = img_A2B[..., :256] # 真人照
img_B = img_A2B[..., 256:] # 卡通图
return img_A, img_B
def __len__(self):
return self.num_samples
@staticmethod
def load_A2B_data(phase):
assert phase in ['train', 'test'], "phase should be set within ['train', 'test']"
# 读取数据集,数据中每张图像包含照片和对应的卡通画。
data_path = 'data/cartoon_A2B/'+phase
return [os.path.join(data_path, x) for x in os.listdir(data_path)]
```
```python
paired_dataset_train = PairedData('train')
paired_dataset_test = PairedData('test')
```
## 第一步:搭建生成器
### 请大家补齐空白处的代码,‘#’ 后是提示。
```python
class UnetGenerator(nn.Layer):
def __init__(self, input_nc=3, output_nc=3, ngf=64):
super(UnetGenerator, self).__init__()
self.down1 = nn.Conv2D(input_nc, ngf, kernel_size=4, stride=2, padding=1)
self.down2 = Downsample(ngf, ngf*2)
self.down3 = Downsample(ngf*2, ngf*4)
self.down4 = Downsample(ngf*4, ngf*8)
self.down5 = Downsample(ngf*8, ngf*8)
self.down6 = Downsample(ngf*8, ngf*8)
self.down7 = Downsample(ngf*8, ngf*8)
self.center = Downsample(ngf*8, ngf*8)
self.up7 = Upsample(ngf*8, ngf*8, use_dropout=True)
self.up6 = Upsample(ngf*8*2, ngf*8, use_dropout=True)
self.up5 = Upsample(ngf*8*2, ngf*8, use_dropout=True)
self.up4 = Upsample(ngf*8*2, ngf*8)
self.up3 = Upsample(ngf*8*2, ngf*4)
self.up2 = Upsample(ngf*4*2, ngf*2)
self.up1 = Upsample(ngf*2*2, ngf)
self.output_block = nn.Sequential(
nn.ReLU(),
nn.Conv2DTranspose(ngf*2, output_nc, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def forward(self, x):
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
d5 = self.down5(d4)
d6 = self.down6(d5)
d7 = self.down7(d6)
c = self.center(d7)
x = self.up7(c, d7)
x = self.up6(x, d6)
x = self.up5(x, d5)
x = self.up4(x, d4)
x = self.up3(x, d3)
x = self.up2(x, d2)
x = self.up1(x, d1)
x = self.output_block(x)
return x
class Downsample(nn.Layer):
# LeakyReLU => conv => batch norm
def __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1):
super(Downsample, self).__init__()
self.layers = nn.Sequential(
nn.LeakyReLU(0.2), # LeakyReLU, leaky=0.2
nn.Conv2D(in_dim, out_dim, kernel_size, stride, padding, bias_attr=False), # Conv2D
nn.BatchNorm2D(out_dim)
)
def forward(self, x):
x = self.layers(x)
return x
class Upsample(nn.Layer):
# ReLU => deconv => batch norm => dropout
def __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1, use_dropout=False):
super(Upsample, self).__init__()
sequence = [
nn.ReLU(), # ReLU
nn.Conv2DTranspose(in_dim, out_dim, kernel_size, stride, padding, bias_attr=False),
# Conv2DTranspose
nn.BatchNorm2D(out_dim)
]
if use_dropout:
sequence.append(nn.Dropout(p=0.5))
self.layers = nn.Sequential(*sequence)
def forward(self, x, skip):
x = self.layers(x)
x = paddle.concat([x, skip], axis=1)
return x
```
## 第二步:鉴别器的搭建
### 请大家补齐空白处的代码,‘#’ 后是提示。
```python
class NLayerDiscriminator(nn.Layer):
def __init__(self, input_nc=6, ndf=64):
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
资源包含文件:设计报告word+源码 提交一张卡通化的成品图,动手完成自己的第一个人像卡通化的应用~详细介绍参考:https://biyezuopin.blog.csdn.net/article/details/123306950
资源推荐
资源详情
资源评论
收起资源包目录
使用photo2cartoon项目生成真人数据对应的人像卡通化.zip (17个子文件)
使用photo2cartoon项目生成真人数据对应的人像卡通化
pix2pix_photo2cartoon
设计报告.docx 902KB
photo2cartoon.ipynb 737KB
LICENSE 11KB
results
epoch020.png 976KB
epoch040.png 978KB
epoch070.png 979KB
epoch050.png 971KB
epoch080.png 976KB
epoch030.png 974KB
epoch010.png 984KB
epoch060.png 976KB
epoch100.png 977KB
epoch090.png 978KB
output_19_0.png 167KB
requirements.txt 125B
output_6_2.png 337KB
README.md 34KB
共 17 条
- 1
资源评论
- 酷爱1号2023-10-17资源不错,内容挺好的,有一定的使用价值,值得借鉴,感谢分享。
- qq_589876792024-04-05发现一个宝藏资源,资源有很高的参考价值,赶紧学起来~
- 2301_774807702024-04-02终于找到了超赞的宝藏资源,果断冲冲冲,支持!
shejizuopin
- 粉丝: 1w+
- 资源: 1300
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功