from tkinter import X
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import math
# This part implements the quantization and dequantization operations.
# The output of the encoder must be a bitstream.
def Num2Bit(Num, B):
    """Expand quantized integer levels into their B-bit binary representation.

    Calls ``.numpy()`` on its input, so it must receive an eager tensor
    (in this file it is invoked through ``tf.py_function``).

    Args:
        Num: Eager tensor indexed as ``(batch, n)``; values are cast to
            ``uint8`` before unpacking, so they are expected to lie in
            ``[0, 2**B - 1]``.
        B: Number of bits kept per value (1..8). May be a Python int or a
            scalar tensor.

    Returns:
        A float32 tensor of shape ``(batch, n * B)`` holding the bits of
        each value, most significant bit first.

    Raises:
        ValueError: If ``B`` is outside ``1..8``.
    """
    # B arrives as a scalar tensor when routed through tf.py_function.
    B = int(B)
    if not 1 <= B <= 8:
        raise ValueError("B must be between 1 and 8, got %d" % B)

    Num_ = Num.numpy()
    width = Num_.shape[1]

    # unpackbits yields 8 bits per uint8 value; keep only the B
    # least-significant ones (previously hard-coded as the last 4).
    all_bits = np.unpackbits(np.array(Num_, np.uint8), axis=1)
    kept_bits = all_bits.reshape(-1, width, 8)[:, :, 8 - B:]
    bit = kept_bits.reshape(-1, width * B)
    return tf.convert_to_tensor(bit, dtype=tf.float32)
def Bit2Num(Bit, B):
    """Collapse groups of B bits back into the integer level they encode.

    Inverse of ``Num2Bit``. Calls ``.numpy()`` on its input, so it must
    receive an eager tensor (in this file it is invoked through
    ``tf.py_function``).

    Args:
        Bit: Eager tensor indexed as ``(batch, n * B)`` containing 0/1
            values, most significant bit first within each group.
        B: Number of bits per encoded value. May be a Python int or a
            scalar tensor.

    Returns:
        A float32 tensor of shape ``(batch, n)`` with the decoded levels.

    Raises:
        ValueError: If ``B`` is not positive or the bit width is not a
            multiple of ``B``.
    """
    # B arrives as a scalar tensor when routed through tf.py_function.
    B = int(B)
    if B < 1:
        raise ValueError("B must be a positive integer, got %d" % B)

    Bit_ = Bit.numpy()
    total_bits = Bit_.shape[1]
    if total_bits % B != 0:
        # Without this check the reshape below could silently regroup bits
        # across sample boundaries.
        raise ValueError(
            "bit width %d is not a multiple of B=%d" % (total_bits, B)
        )

    Bit_ = np.reshape(Bit_, [-1, total_bits // B, B])

    # Positional weights 2**(B-1), ..., 2, 1 applied along the last axis.
    # (The previous loop seeded its accumulator from index 1, which fails
    # for B == 1.)
    weights = 2 ** np.arange(B - 1, -1, -1)
    num = np.sum(Bit_ * weights, axis=-1)
    return tf.cast(num, dtype=tf.float32)
@tf.custom_gradient
def QuantizationOp(x, B):
    """Uniformly quantize ``x`` to B bits and emit the result as a bitstream.

    The ``(result, grad_fn)`` return pair follows the ``tf.custom_gradient``
    contract; the decorator is required so that callers receive only the
    result tensor and the identity gradient below is actually used.

    Args:
        x: Float tensor; values are expected in ``[0, 1]``
            (NOTE(review): range assumed from the ``x * 2**B - 0.5``
            mapping; confirm against the producer of ``x``).
        B: Number of bits per value.

    Returns:
        Float32 tensor of bits produced by ``Num2Bit``.
    """
    step = tf.cast((2 ** B), dtype=tf.float32)
    levels = tf.cast(tf.round(x * step - 0.5), dtype=tf.float32)
    # x == 1.0 rounds to 2**B, which does not fit in B bits and would wrap
    # to level 0; negative values would wrap in the uint8 cast. Saturate.
    levels = tf.clip_by_value(levels, 0.0, step - 1.0)
    result = tf.py_function(func=Num2Bit, inp=[levels, B], Tout=tf.float32)

    def custom_grad(dy):
        # Straight-through estimator: pass the upstream gradient unchanged.
        # One entry is returned per input (x, B), as in the original code.
        grad = dy
        return (grad, grad)

    return result, custom_grad
class QuantizationLayer(tf.keras.layers.Layer):
    """Keras layer wrapper around ``QuantizationOp`` with a fixed bit width."""

    def __init__(self, B, **kwargs):
        """Create the layer.

        Args:
            B: Number of quantization bits per value.
            **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
                ``trainable``, ...). They are forwarded to the base class so
                that a config produced by ``get_config`` can rebuild the
                layer without losing them.
        """
        super(QuantizationLayer, self).__init__(**kwargs)
        self.B = B

    def call(self, x):
        """Return ``QuantizationOp(x, self.B)``."""
        return QuantizationOp(x, self.B)

    def get_config(self):
        """Return the base layer config extended with ``B`` for serialization."""
        base_config = super(QuantizationLayer, self).get_config()
        base_config['B'] = self.B
        return base_config
@tf.custom_gradient
def DequantizationOp(x, B):
    """Decode a bitstream back to real values at the centre of each bin.

    The ``(result, grad_fn)`` return pair follows the ``tf.custom_gradient``
    contract; the decorator is required so that callers receive only the
    result tensor and the identity gradient below is actually used.

    Args:
        x: Float tensor of bits, B bits per encoded value.
        B: Number of bits per value.

    Returns:
        Float32 tensor computed as ``(level + 0.5) / 2**B``, where ``level``
        is the integer decoded by ``Bit2Num``.
    """
    levels = tf.py_function(func=Bit2Num, inp=[x, B], Tout=tf.float32)
    step = tf.cast((2 ** B), dtype=tf.float32)
    # +0.5 places the reconstruction in the middle of the quantization bin.
    result = tf.cast((levels + 0.5) / step, dtype=tf.float32)

    def custom_grad(dy):
        # Straight-through estimator: pass the upstream gradient unchanged.
        # One entry is returned per input (x, B), as in the original code.
        grad = dy * 1
        return (grad, grad)

    return result, custom_grad
class DeuantizationLayer(tf.keras.layers.Layer):
    """Keras layer wrapper around ``DequantizationOp`` with a fixed bit width.

    NOTE: the class name is misspelled ("Deuantization"); it is kept as-is
    because callers and saved configs may refer to it by this name.
    """

    def __init__(self, B, **kwargs):
        """Create the layer.

        Args:
            B: Number of bits per encoded value.
            **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
                ``trainable``, ...). They are forwarded to the base class so
                that a config produced by ``get_config`` can rebuild the
                layer without losing them.
        """
        super(DeuantizationLayer, self).__init__(**kwargs)
        self.B = B

    def call(self, x):
        """Return ``DequantizationOp(x, self.B)``."""
        return DequantizationOp(x, self.B)

    def get_config(self):
        """Return the base layer config extended with ``B`` for serialization."""
        base_config = super(DeuantizationLayer, self).get_config()
        base_config['B'] = self.B
        return base_config
# Module-level hyper-parameters for the patch-based model definitions.

# Input images are resized to image_size x image_size.
image_size = 128
# Side length of the square patches extracted from the input images.
patch_size = 16
# Patches along one side, times patches along the other side.
num_patches = (image_size // patch_size) * (image_size // patch_size)
projection_dim = 256
num_heads = 4
# Size of the transformer layers.
transformer_units = [2 * projection_dim]
transformer_layers = 8
dec_layers = 8
# Size of the dense layers of the final classifier.
mlp_head_units = [512, 256]
def mlp(x, hidden_units, dropout_rate, trainable=True):
    """Apply one Dense(GELU) + Dropout pair per entry of ``hidden_units``.

    Args:
        x: Input tensor fed to the first Dense layer.
        hidden_units: Iterable of output widths, one per Dense layer.
        dropout_rate: Rate passed to every Dropout layer.
        trainable: Forwarded to each Dense layer's ``trainable`` flag.

    Returns:
        The output of the last Dropout layer, or ``x`` unchanged when
        ``hidden_units`` is empty.
    """
    out = x
    for width in hidden_units:
        dense = layers.Dense(width, activation=tf.nn.gelu, trainable=trainable)
        dropout = layers.Dropout(dropout_rate)
        out = dropout(dense(out))
    return out
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches."""

    def __init__(self, patch_size, **kwargs):
        """Create the layer.

        Args:
            patch_size: Side length of each square patch; also used as the
                stride, so patches do not overlap.
            **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
                ...), accepted so that ``from_config`` can rebuild the layer
                from the dict returned by ``get_config``.
        """
        super(Patches, self).__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return patches reshaped to ``(batch, num_patches, patch_dims)``.

        Args:
            images: 4-D image batch as required by
                ``tf.image.extract_patches``.
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            # NOTE(review): padding was missing from the call; "VALID" keeps
            # only complete patches. Confirm this is the intended mode.
            padding="VALID",
        )
        # Last axis holds the flattened pixels of one patch.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the base layer config extended with ``patch_size``."""
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config
class PatchEncoder(layers.Layer):
    """Linearly project patches and add a learned position embedding."""

    def __init__(self, num_patches, projection_dim, trainable, **kwargs):
        """Create the layer.

        Args:
            num_patches: Number of patch positions to embed.
            projection_dim: Output width of both the Dense projection and
                the position embedding.
            trainable: Whether the projection and embedding weights are
                trainable. Also forwarded to the base class so that the
                flag is recorded in ``get_config``.
            **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
                ...), accepted so that ``from_config`` can rebuild the layer.
        """
        super(PatchEncoder, self).__init__(trainable=trainable, **kwargs)
        self.num_patches = num_patches
        # Kept so get_config can return the constructor argument rather
        # than the sub-layer objects.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim, trainable=trainable)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim, trainable=trainable
        )

    def call(self, patch):
        """Return ``projection(patch)`` plus the embedding of each position."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the base config extended with the constructor arguments.

        ``trainable`` is already part of the base layer config.
        """
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config
# transformer
def Encoder(x, feedback_bits, trainable=True):
B = 4
with tf.compat.v1.variable_scope('Encoder'):
# pad 0.5 to 128*128
x=layers.ZeroPadding2D(padding=(1, 0))(x)
x = layers.Conv2D(32, 7, padding='same', trainable=False,name="enc_conv_1")(x)
x = layers.BatchNormalization(trainable=False,name="enc_bn_1")(x)
x = layers.LeakyReLU(alpha=0.1)(x)
x = layers.Conv2D(16, 7, padding='same', trainable=False,name="enc_conv_2")(x)
x = layers.BatchNormalization(trainable=False,name="enc_bn_2")(x)
y = layers.LeakyReLU(alpha=0.1)(x)
x = layers.Conv2D(2, 7, padding='same', trainable=False,name="enc_conv_3")(y)
x = layers.BatchNormalization(trainable=False,name="enc_bn_3")(x)
x = layers.LeakyReLU(alpha=0.1)(x)
x = layers.Conv2D(32, 7, padding='same', trainable=False,name="enc_conv_4")(x)
x = layers.BatchNormalization(trainable=False,name="enc_bn_4")(x)
x = layers.LeakyReLU(alpha=0.1)(x)
x = layers.Conv2D(16, 7, padding='same', trainable=False,name="enc_conv_5")(x)
x = layers.BatchNormalization(trainable=False,name="enc_bn_5")(x)
x = layers.LeakyReLU(alpha=0.1)(x)
x = layers.Add()([x, y])
x = layers.Conv2D(2, 7, padding='same', trainable=False,name="enc_conv_6")(x)
x = layers.BatchNormalization(trainable=False,name="enc_bn_6")(x)
x = layers.LeakyReLU(alpha=0.1)(x)
# Augment data.
# augmented = data_augmentation(inputs)
# x = layers.Normalization()(x)
# Create patches.
patches = Patches(patch_size)(x)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim, trainable=trainable)(patches)
# Create multiple layers of the Transformer block.
# for _ in range(transformer_layers):
# # Layer normalization 1.
# x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS, trainable=trainable)(encoded_patches)
# # Create a multi-head attention layer.
# attention_output = layers.MultiHeadAttention(
# num_heads=num_heads, key_dim=projection_dim, dropout=0.1, trainable=trainable
# )(x1, x1)
# # Skip connection 1.
# x2 = layers.Add()([attention_output, encoded_patches])
# # Layer normalization 2.
# x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS, trainable=trainable)(x2)
# # MLP.
# x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1, trainable=trainable)
# # Skip connection 2.
# encoded_patches = layers.Add()([x3, x2])
# Create a [batch_size, projection_dim] tensor.
representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS,trainable=trainable,name="enc_ln_3")(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.1)(representation)
# Add MLP.
features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.1,trainable=trainable)
# features = representation
# Classify outputs.