#!/usr/bin/env python
import warnings
warnings.simplefilter(action='ignore')
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import os
import numpy as np
from tqdm import tqdm
import keras as K
import keras.layers as l
import keras.optimizers as o
from keras.models import Model
from auxilliary import copy_model, WeightClip, wasserstein_loss, one_hot
from batch_generator import BatchGenerator, BatchGenerator_Numpy
class AdaPy():
    """
    Simple class for adversarial domain adaptation based on keras and tensorflow.

    Input:
    - source_representer : keras model (encoder) trained on source data; its
      output is the latent representation
    - source_classifier : keras classifier trained on source data, applied to
      the output of the representer
    - index_to_label_dictionary : optional mapping from class index to label
      name; when omitted every index maps to an empty name
    - algorithm : "adda" (binary cross-entropy discriminator) or "wadda"
      (Wasserstein loss with a weight-clipped discriminator)
    - domain_discriminator : "linear" or a keras model; custom models are only
      handled for "adda" #TODO: Add a 1-hidden layer option
    - discriminator_lr : learning rate of the domain discriminator
    - target_representer_lr : learning rate used when training the target
      representer on target data
    - discriminator_per_representer_iterations : number of discriminator
      training iterations per representer training iteration
    - discriminator_per_representer_iterations_for0 : presumably the number of
      discriminator iterations for the first (0th) round -- TODO confirm
    - batch_size : number of samples in each batch
    - weight_clip_threshold : clipping threshold for the "wadda" discriminator
      weights
    - epochs : number of epochs the model will be trained for
    - output_directory : directory to save output models
    - lipschitz : NOTE(review): accepted by __init__ but not read by the code
      shown here; "wadda" always uses weight clipping
    """
def __init__(self,
             source_representer,
             source_classifier,
             index_to_label_dictionary=None,
             algorithm="adda",
             domain_discriminator="linear",
             discriminator_lr=0.001,
             target_representer_lr=0.0002,
             discriminator_per_representer_iterations=10,
             discriminator_per_representer_iterations_for0=25,
             batch_size=256,
             weight_clip_threshold=0.05,
             epochs=5,
             output_directory="Models/",
             lipschitz="clip"):
    """Store the configuration, build every sub-model and compile them.

    Args:
        source_representer: keras Model mapping inputs to a latent vector;
            ``output_shape[1]`` is taken as the latent dimensionality.
        source_classifier: keras Model mapping latent vectors to class scores;
            ``output_shape[-1]`` is taken as the number of labels.
        index_to_label_dictionary: optional mapping from class index to label
            name. When None, every index in ``range(nlabels)`` maps to "".
        algorithm: "adda" or "wadda".
        domain_discriminator: "linear", or (for "adda") a keras Model.
        discriminator_lr: Adam learning rate of the domain discriminator.
        target_representer_lr: Adam learning rate used for the target
            training model and the target model.
        discriminator_per_representer_iterations: discriminator iterations
            per representer training iteration.
        discriminator_per_representer_iterations_for0: presumably the
            discriminator iterations used for the first (0th) round --
            TODO confirm against the training loop.
        batch_size: number of samples per batch.
        weight_clip_threshold: threshold passed to ``WeightClip`` for the
            "wadda" discriminator.
        epochs: number of training epochs.
        output_directory: directory intended for saved models.
        lipschitz: NOTE(review): accepted for API compatibility but not read
            by any code shown here; "wadda" always uses weight clipping.

    Raises:
        AssertionError: if ``algorithm`` is not "adda"/"wadda", or if either
            source model is not a keras Model.
    """
    assert algorithm in ("adda", "wadda"), "Invalid choice of algorithm"
    # ``Model`` is imported from keras.models, the public home of the Keras
    # model class. ``K.engine.training.Model`` is an internal module path
    # that newer Keras releases no longer expose; it is the same class in
    # Keras 2.x, so this check accepts exactly the same objects.
    assert isinstance(source_representer, Model) and isinstance(source_classifier, Model), \
        "Provide keras models for source encoder and classifier"
    # TODO: Add assertions for all arguments
    self.__output_directory = output_directory
    self.__algorithm = algorithm
    self.__discriminator_learning_rate = discriminator_lr
    self.__target_representer_learning_rate = target_representer_lr
    self.__weight_clip_threshold = weight_clip_threshold
    self.__shuffle = True
    self.__epochs = epochs
    self.__discriminator_per_representer_iterations = discriminator_per_representer_iterations
    self.__discriminator_per_representer_iterations_for0 = discriminator_per_representer_iterations_for0
    self.__batch_size = batch_size
    self.__latent_dimensions = source_representer.output_shape[1]
    # Keras reports input_shape with the leading batch dimension (e.g. (None, ...)).
    self.__shape = source_representer.input_shape
    self.__nlabels = source_classifier.output_shape[-1]
    if index_to_label_dictionary is None:
        self.__index_to_label_dictionary = {k: "" for k in range(self.__nlabels)}
    else:
        self.__index_to_label_dictionary = index_to_label_dictionary
    self.__initialize_models(source_representer, source_classifier, domain_discriminator)
    self.__define_models_for_training_and_inference()
    self.compile_models()
def compile_models(self):
    """
    Method to compile all models of object.

    The domain discriminator and the target training model use the
    adversarial loss of the chosen algorithm: binary cross-entropy for
    "adda", ``wasserstein_loss`` for "wadda". The target model is always
    compiled with categorical cross-entropy and an accuracy metric. Every
    model gets its own Adam optimizer instance.
    """
    def adam(learning_rate):
        # Keras >= 2.3 names this argument ``learning_rate``. Recent releases
        # reject or silently ignore the old ``lr`` keyword (any warning is
        # hidden by this module's warning/log filters), which would replace
        # the configured rate with Adam's default. Keras < 2.3 only accepts
        # ``lr`` and raises TypeError for unknown keywords, so fall back there.
        try:
            return o.Adam(learning_rate=learning_rate)
        except TypeError:
            return o.Adam(lr=learning_rate)

    if self.__algorithm == "adda":
        adversarial_loss = "binary_crossentropy"
    elif self.__algorithm == "wadda":
        adversarial_loss = wasserstein_loss
    else:
        # __init__ only accepts "adda"/"wadda"; any other value compiles
        # nothing, as before.
        return

    # Compile the discriminator while it is trainable, then freeze it before
    # compiling __train_target -- presumably so that updates through the
    # combined model (defined elsewhere) leave the discriminator untouched.
    self.__domain_discriminator.trainable = True
    self.__domain_discriminator.compile(loss=adversarial_loss,
                                        optimizer=adam(self.__discriminator_learning_rate))
    self.__domain_discriminator.trainable = False
    self.__train_target.compile(loss=adversarial_loss,
                                optimizer=adam(self.__target_representer_learning_rate))
    # TODO: Possibly not
    self.__target_model.compile(loss="categorical_crossentropy",
                                optimizer=adam(self.__target_representer_learning_rate),
                                metrics=["accuracy"])
def __build_domain_discriminator(self, domain_discriminator):
    """Create, or validate and adopt, the domain discriminator.

    For "adda", ``domain_discriminator`` is either the string "linear" (a
    single sigmoid unit on the latent representation) or a keras Model with
    one output unit whose input width equals the latent dimensionality.
    For "wadda", only "linear" is supported: a single linear unit whose
    kernel is constrained by ``WeightClip(weight_clip_threshold)``.

    The result is stored in ``self.__domain_discriminator`` and named
    "DomainDiscriminator".

    Raises:
        AssertionError: if a custom "adda" discriminator is not a keras
            Model, does not have exactly one output unit, or its input width
            differs from the latent dimensionality.
        NotImplementedError: if a custom discriminator is given for "wadda".
    """
    if self.__algorithm == "adda":
        if domain_discriminator == "linear":
            latent_representation = l.Input(shape=(self.__latent_dimensions,))
            classifier = l.Dense(1, activation="sigmoid")(latent_representation)
            # Pass the name to the constructor: tf.keras exposes Model.name
            # as a read-only property, so assigning it afterwards fails there.
            self.__domain_discriminator = Model(latent_representation, classifier,
                                                name="DomainDiscriminator")
        else:
            assert isinstance(domain_discriminator, Model), "Provide keras model for domain discriminator"
            # output_shape / input_shape are tuples that include the batch
            # axis (e.g. (None, 1)); comparing the whole tuple to an int is
            # always False, so compare the last axis instead.
            assert domain_discriminator.output_shape[-1] == 1, "Domain discriminator must be a binary classifier"
            assert domain_discriminator.input_shape[-1] == self.__latent_dimensions, \
                "Domain discriminator input dimensionality was invalid"
            self.__domain_discriminator = domain_discriminator
            if self.__domain_discriminator.name != "DomainDiscriminator":
                try:
                    self.__domain_discriminator.name = "DomainDiscriminator"
                except AttributeError:
                    # tf.keras: ``name`` is read-only and backed by ``_name``.
                    self.__domain_discriminator._name = "DomainDiscriminator"
    elif self.__algorithm == "wadda":
        if domain_discriminator == "linear":
            latent_representation = l.Input(shape=(self.__latent_dimensions,))
            # Linear critic whose kernel is clipped by WeightClip (assumed to
            # be a keras Constraint from auxilliary). ``kernel_constraint`` is
            # the Keras 2 name of the Keras 1 ``W_constraint`` argument.
            classifier = l.Dense(1, activation='linear', kernel_initializer='he_normal',
                                 kernel_constraint=WeightClip(self.__weight_clip_threshold))(latent_representation)
            self.__domain_discriminator = Model(latent_representation, classifier,
                                                name="DomainDiscriminator")
        else:
            # Previously nothing was assigned here, which only surfaced later
            # as an AttributeError in compile_models.
            raise NotImplementedError(
                "Custom domain discriminators are not supported for 'wadda'; use 'linear'")
def __train_domain_discriminator(self, iterations, target_label, source_label):
    """Run ``iterations`` rounds of domain-discriminator updates.

    Each round pulls one batch from ``self.target_data`` and one from
    ``self.source_data`` (element ``[0]`` of ``get_batch()``, presumably the
    inputs), encodes them with the target and source representers, and then
    performs one ``train_on_batch`` step on the target latents labelled
    ``target_label`` followed by one on the source latents labelled
    ``source_label``.
    """
    for _round in range(iterations):
        # TODO: issue Tensorboard
        target_inputs = self.target_data.get_batch()[0]
        encoded_target = self.__target_representer.predict(target_inputs)
        source_inputs = self.source_data.get_batch()[0]
        encoded_source = self.__source_representer.predict(source_inputs)
        # TODO: Handle source batch differently?
        labelled_batches = ((encoded_target, target_label),
                            (encoded_source, source_label))
        for latents, domain_labels in labelled_batches:
            self.__domain_discriminator.train_on_batch(latents, domain_labels)
def __initialize_models(self, source_representer, source_classifier, domain_discriminator):
    """Create this instance's working copies of the supplied models.

    The source encoder is copied twice, as "SourceRepresenter" and
    "TargetRepresenter", and the classifier once, as "Classifier", via
    ``copy_model``. The domain discriminator is then built from
    ``domain_discriminator``.
    """
    # Both encoders start from the same source model; this assumes
    # copy_model returns an independent clone (TODO confirm in auxilliary).
    source_copy = copy_model(source_representer, "SourceRepresenter")
    self.__source_representer = source_copy
    target_copy = copy_model(source_representer, "TargetRepresenter")
    self.__target_representer = target_copy
    classifier_copy = copy_model(source_classifier, "Classifier")
    self.__source_classifier = classifier_copy
    self.__build_domain_discriminator(domain_discriminator)
def __define_models_for_training_and_inference(self):
representer_input = l.Input(self.input_shape)
source_representer_output = self.__source_representer(representer_input)
source_classifier = self.__source_classifier(source_representer_output)
self.__source_model = Model(representer_input, source_classifier)
self.__source_model.name = "SourceModel"
target_representer_output = self.__target_representer(representer_input)
target_classifier = self.__source_classifier(target_representer_output)
self.__target_model = Model(representer_input, target_classifier)
self.__target_model.name = "TargetModel"
target_representer_output = self.__target_representer(representer_input)
domain_discriminat
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
资源分类:Python库 所属语言:Python 资源全名:adapy-0.3.tar.gz 资源来源:官方 安装方法:https://lanzao.blog.csdn.net/article/details/101784059
资源推荐
资源详情
资源评论
收起资源包目录
adapy-0.3.tar.gz (7个子文件)
adapy-0.3
PKG-INFO 755B
setup.cfg 39B
adapy
batch_generator.py 4KB
adaptation.py 15KB
__init__.py 241B
auxilliary.py 4KB
setup.py 936B
共 7 条
- 1
资源评论
挣扎的蓝藻
- 粉丝: 14w+
- 资源: 15万+
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 基于Android系统的手机地图应用软件开发中文3.78MB最新版本
- AndroidStudio环境下的jni调用(NDK)的方法中文最新版本
- Vue + UEditor + v-model 实体绑定.zip
- 最新版本ArcGISForAndroidEclipse环境配置中文最新版本
- VS Code 的 Vue 工具 .zip
- AndroidStudio快捷键中文最新版本
- TypeScript 和 Vue 的入门模板,带有详细的 README,描述了如何将两者结合使用 .zip
- The Net Ninja YouTube 频道上的 Vue.js 2 播放列表的课程文件.zip
- TDesign 的 Vue3.x UI 组件库 .zip
- 机器学习,深度学习,卷积神经网络ppt详细说明,详细推导
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功