import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Layer, Embedding
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
import numpy as np
import os
param_model_path = "models/cvae/"
param_result_path = "result/cvae/"
if not os.path.exists(param_model_path):
os.makedirs(param_model_path)
if not os.path.exists(param_result_path):
os.makedirs(param_result_path)
param_img_width, param_img_heigth = 28, 28
param_intermediate_size = 512
param_latent_size = 20
param_batch_size = 512
param_epochs = 100
class Sampling(Layer):
def __init__(self, **kwargs):
super(Sampling, self).__init__(**kwargs)
def call(self, inputs, **kwargs):
mu, logvar = inputs
epsilon = tf.random.normal(shape=tf.shape(mu))
return mu + tf.exp(0.5 * logvar) * epsilon
class VAELOSS(Layer):
def __init__(self, **kwargs):
super(VAELOSS, self).__init__(**kwargs)
def call(self, inputs, **kwargs):
output, realimg, mu, logvar = inputs
bceloss = binary_crossentropy(realimg, output)
klloss = -0.5 * (logvar + 1 - mu ** 2 - tf.exp(logvar))
klloss = tf.reduce_sum(klloss, axis=-1)
loss = tf.reduce_mean(param_img_heigth * param_img_width * bceloss + klloss)
self.add_loss(loss)
return output
class CustomEmbedding(Layer):
def __init__(self, outputsize, **kwargs):
super(CustomEmbedding, self).__init__(**kwargs)
self.outputsize = outputsize
def build(self, input_shape):
self.embedding = Embedding(10, self.outputsize)
super(CustomEmbedding, self).build(input_shape)
def call(self, inputs, **kwargs):
x = self.embedding(inputs)
return tf.squeeze(x, axis=1)
class USER:
def build_model(self, summary=False, plot=False):
realimg = Input(shape=[param_img_heigth * param_img_width, ], name="realimg")
label = Input(shape=[1, ], name="label", dtype=tf.int32)
label_embed = CustomEmbedding(param_img_heigth * param_img_width, name="labelembeddingen")(label)
x = Dense(param_intermediate_size, activation="relu", name="interdense")(realimg + label_embed)
mu = Dense(param_latent_size, name="mu")(x)
logvar = Dense(param_latent_size, name="logvar")(x)
z = Sampling(name="sampling")(inputs=(mu, logvar))
encoder = Model(inputs=[realimg, label], outputs=[mu, logvar, z], name="encoder")
latentimg = Input(shape=[param_latent_size, ], name="latentimg")
label_embed = CustomEmbedding(param_latent_size, name="labelembeddingde")(label)
x = Dense(param_intermediate_size, activation="relu", name="invinterdense")(latentimg + label_embed)
outputimg = Dense(param_img_width * param_img_heigth, activation="sigmoid", name="invimg")(x)
decoder = Model(inputs=[latentimg, label], outputs=outputimg, name="decoder")
output1 = decoder((encoder([realimg, label])[2], label))
output = VAELOSS(name="vaeloss")(inputs=(output1, realimg, mu, logvar))
vae = Model(inputs=[realimg, label], outputs=output, name="vae")
if summary:
encoder.summary()
decoder.summary()
vae.summary()
if plot:
plot_model(encoder, param_model_path + "encoder.png", show_shapes=True)
plot_model(decoder, param_model_path + "decoder.png", show_shapes=True)
plot_model(vae, param_model_path + "vae.png", show_shapes=True)
return encoder, decoder, vae
def train(self):
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, [-1, param_img_width * param_img_heigth])
x_test = np.reshape(x_test, [-1, param_img_width * param_img_heigth])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
encoder, decoder, vae = self.build_model(summary=True, plot=True)
vae.compile(optimizer=Adam(0.01))
vae.fit([x_train, y_train], epochs=param_epochs, batch_size=param_batch_size, validation_data=[x_test, y_test])
encoder.save_weights(param_model_path + "encoder.h5")
decoder.save_weights(param_model_path + "decoder.h5")
vae.save_weights(param_model_path + "vae.h5")
def predict(self):
_, _, vae = self.build_model()
vae.load_weights(param_model_path + "vae.h5")
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = np.reshape(x_test[:11], [-1, param_img_width * param_img_heigth])
x_test = x_test.astype('float32') / 255
y_test = y_test[:11]
y_predict = vae.predict([x_test, y_test])
realimages = np.reshape(x_test * 255, [-1, param_img_width, param_img_heigth, 1])
fakeimages = np.reshape(y_predict * 255, [-1, param_img_width, param_img_heigth, 1])
for i in range(11):
realimg = tf.image.encode_jpeg(realimages[i], quality=100)
fakeimg = tf.image.encode_jpeg(fakeimages[i], quality=100)
with tf.io.gfile.GFile(param_result_path + 'realimg' + str(i) + '_' + str(y_test[i]) + '.jpg',
'wb') as file:
file.write(realimg.numpy())
with tf.io.gfile.GFile(param_result_path + 'fakeimg' + str(i) + '_' + str(y_test[i]) + '.jpg',
'wb') as file:
file.write(fakeimg.numpy())
def test(self):
_, decoder, _ = self.build_model()
decoder.load_weights(param_model_path + "decoder.h5")
latentimgs = tf.random.normal([11, param_latent_size])
labels = tf.random.uniform([11], 0, 10, dtype=tf.int32)
fakeimgs = decoder.predict([latentimgs, labels])
fakeimgs = np.reshape(fakeimgs * 255, [-1, param_img_width, param_img_heigth, 1])
for i in range(11):
fakeimg = tf.image.encode_jpeg(fakeimgs[i], quality=100)
with tf.io.gfile.GFile(param_result_path + 'testfakeimg' + str(i) + '_' + str(labels[i].numpy()) + '.jpg',
'wb') as file:
file.write(fakeimg.numpy())
if __name__ == "__main__":
user = USER()
user.train()
user.predict()
user.test()
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
【资源介绍】 变分自编码器和条件变分自编码器python实现源码+说明.zip变分自编码器和条件变分自编码器python实现源码+说明.zip变分自编码器和条件变分自编码器python实现源码+说明.zip变分自编码器和条件变分自编码器python实现源码+说明.zip变分自编码器和条件变分自编码器python实现源码+说明.zip变分自编码器和条件变分自编码器python实现源码+说明.zip变分自编码器和条件变分自编码器python实现源码+说明.zip 变分自编码器和条件变分自编码器python实现源码+说明.zip 变分自编码器和条件变分自编码器python实现源码+说明.zip 【备注】 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用!有问题请及时沟通交流。 2、适用人群:计算机相关专业(如计科、信息安全、数据科学与大数据技术、人工智能、通信、物联网、自动化、电子信息等)在校学生、专业老师或者企业员工下载使用。 3、用途:项目具有较高的学习借鉴价值,也适用于小白学习入门进阶。当然也可作为毕设项目、课程设计、大作业、初期项目立项演示等。 4、如果基础还行,或者热爱钻研,亦可在此项目代码基础上进行修改添加,实现其他不同功能。 欢迎下载,沟通交流,互相学习,共同进步!
资源推荐
资源详情
资源评论
收起资源包目录
变分自编码器和条件变分自编码器python实现源码+说明.zip (82个子文件)
说明.md 199B
.idea
VAE.iml 326B
other.xml 186B
vcs.xml 180B
misc.xml 297B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 258B
.gitignore 176B
models
cvae
encoder.png 38KB
vae.png 65KB
decoder.png 26KB
vae
encoder.png 27KB
vae.png 43KB
decoder.png 13KB
vae.py 5KB
cvae.py 6KB
result
cvae
testfakeimg1_9.jpg 1018B
testfakeimg4_2.jpg 1020B
realimg4_4.jpg 1KB
testfakeimg10_5.jpg 886B
realimg6_4.jpg 1KB
fakeimg2_1.jpg 723B
realimg9_9.jpg 1007B
fakeimg4_4.jpg 985B
testfakeimg9_5.jpg 930B
realimg2_1.jpg 720B
testfakeimg8_5.jpg 994B
fakeimg1_2.jpg 971B
fakeimg3_0.jpg 965B
realimg7_9.jpg 868B
realimg8_5.jpg 1KB
testfakeimg0_0.jpg 972B
testfakeimg6_6.jpg 910B
fakeimg7_9.jpg 809B
fakeimg8_5.jpg 1KB
testfakeimg5_9.jpg 876B
realimg3_0.jpg 1KB
testfakeimg2_1.jpg 824B
realimg10_0.jpg 1011B
fakeimg0_7.jpg 957B
testfakeimg3_1.jpg 859B
fakeimg5_1.jpg 762B
testfakeimg7_2.jpg 1KB
realimg5_1.jpg 787B
fakeimg10_0.jpg 963B
fakeimg9_9.jpg 1006B
realimg1_2.jpg 982B
fakeimg6_4.jpg 969B
realimg0_7.jpg 970B
vae
fakeimg2.jpg 729B
fakeimg10.jpg 976B
realimg2.jpg 720B
testfakeimg1.jpg 1KB
realimg9.jpg 1007B
testfakeimg7.jpg 1KB
testfakeimg6.jpg 1KB
testfakeimg0.jpg 897B
testfakeimg9.jpg 1KB
testfakeimg2.jpg 917B
testfakeimg8.jpg 896B
testfakeimg4.jpg 712B
testfakeimg10.jpg 886B
realimg7.jpg 868B
realimg5.jpg 787B
testfakeimg5.jpg 959B
realimg1.jpg 982B
fakeimg8.jpg 974B
realimg6.jpg 1KB
testfakeimg3.jpg 1KB
fakeimg5.jpg 761B
fakeimg0.jpg 935B
realimg4.jpg 1KB
realimg0.jpg 970B
realimg8.jpg 1KB
realimg3.jpg 1KB
fakeimg9.jpg 987B
fakeimg1.jpg 1009B
fakeimg7.jpg 878B
fakeimg6.jpg 985B
realimg10.jpg 1011B
fakeimg4.jpg 954B
fakeimg3.jpg 968B
共 82 条
- 1
资源评论
.whl
- 粉丝: 3935
- 资源: 4861
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 冷拉墙板制袋机(含工程图)sw20可编辑全套技术资料100%好用.zip
- 基于小程序的农业电商服务系统源码(小程序毕业设计完整源码+LW).zip
- 可调角度切割机sw18可编辑全套技术资料100%好用.zip
- 基于小程序的农产品自主供销小程序源码(小程序毕业设计完整源码+LW).zip
- 仓储系统web端 vue
- 基于JavaScript的签到管理系统设计源码
- 基于小程序的医笙小程序设计与前端开发源码(小程序毕业设计完整源码).zip
- 仓储系统APP端,uniapp
- 螺旋输送机sw17全套技术资料100%好用.zip
- 基于小程序的医院核酸检测预约挂号源码(小程序毕业设计完整源码+LW).zip
- 密封圈安装机sw18可编辑全套技术资料100%好用.zip
- 基于小程序的医院预约挂号系统小程序源码(小程序毕业设计完整源码+LW).zip
- 基于小程序的同城交易小程序源码(小程序毕业设计完整源码).zip
- 基于小程序的在线办公小程序源码(小程序毕业设计完整源码+LW).zip
- 面板自动上料热熔机(含DFM,BOM)sw17可编辑全套技术资料100%好用.zip
- 奶瓶灌装线step全套技术资料100%好用.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功