import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers, regularizers
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Input, Dense, Activation, BatchNormalization, Dropout, Lambda, Concatenate, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.utils import to_categorical
from sklearn.utils import shuffle
import numpy as np
from tensorflow.python.keras.layers import Layer
import anndata
import scbean.tools.utils as tl
def sampling(args):
z_mean, z_log_var = args
batch = K.shape(z_mean)[0]
dim = K.int_shape(z_mean)[1]
epsilon = K.random_normal(shape=(batch, dim))
return z_mean + K.exp(0.5 * z_log_var) * epsilon
#
# def reverse_gradient(X, hp_lambda):
# '''Flips the sign of the incoming gradient during training.'''
# try:
# reverse_gradient.num_calls += 1
# except AttributeError:
# reverse_gradient.num_calls = 1
#
# grad_name = "GradientReversal%d" % reverse_gradient.num_calls
#
# @tf.RegisterGradient(grad_name)
# def _flip_gradients(op, grad):
# return [tf.negative(grad) * hp_lambda]
#
# g = tf.compat.v1.get_default_graph()
# with g.gradient_override_map({'Identity': grad_name}):
# y = tf.identity(X)
#
# return y
#
#
# class GradientReversal(Layer):
# '''Flip the sign of gradient during training.'''
# def __init__(self, hp_lambda, **kwargs):
# super(GradientReversal, self).__init__(**kwargs)
# self.supports_masking = False
# self.hp_lambda = hp_lambda
#
# def build(self, input_shape):
# self._trainable_weights = []
#
# def call(self, x, mask=None):
# return reverse_gradient(x, self.hp_lambda)
#
# def get_output_shape_for(self, input_shape):
# return input_shape
#
# def get_config(self):
# config = {'hp_lambda': self.hp_lambda}
# base_config = super(GradientReversal, self).get_config()
# return dict(list(base_config.items()) + list(config.items()))
@tf.custom_gradient
def grad_reverse(x):
y = tf.identity(x)
def custom_grad(dy):
return -dy
return y, custom_grad
class GradReverse(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
def call(self, x):
return grad_reverse(x)
class DAVAE:
def __init__(self, input_size, batches=2, domain_scale_factor=1.0,
hidden_layers=[128, 64, 32, 5], path=''):
self.input_size = input_size
self.path = path
self.dann_vae = None
self.inputs = None
self.outputs_x = None
self.initializers = "glorot_uniform"
self.optimizer = optimizers.Adam(lr=0.01)
self.hidden_layers = hidden_layers
self.domain_scale_factor = domain_scale_factor
self.dropout_rate_small = 0.01
self.dropout_rate_big = 0.05
self.kernel_regularizer = regularizers.l1_l2(l1=0.00, l2=0.00)
self.validation_split = 0.0
self.batches = batches
self.dropout_rate = 0.01
callbacks = []
checkpointer = ModelCheckpoint(filepath=path + "vae_weights.h5", verbose=1, save_best_only=False,
save_weights_only=True)
reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.8, patience=100, min_lr=0.0001)
early_stop = EarlyStopping(monitor='loss', patience=200)
tensor_board = TensorBoard(log_dir=path + 'logs/')
callbacks.append(checkpointer)
callbacks.append(reduce_lr)
callbacks.append(early_stop)
callbacks.append(tensor_board)
self.callbacks = callbacks
def build(self):
Relu = "relu"
en_ly_size = len(self.hidden_layers)
z_size = self.hidden_layers[en_ly_size - 1]
inputs_x = Input(shape=(self.input_size,), name='inputs')
inputs_batch = Input(shape=(self.batches,), name='inputs_batch')
inputs_loss_weights = Input(shape=(1,), name='inputs_weights')
x = inputs_x
for i in range(en_ly_size):
if i == en_ly_size - 1:
break
ns = self.hidden_layers[i]
x = Dense(ns, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers)(x)
x = BatchNormalization(center=True, scale=False)(x)
x = Activation(Relu)(x)
x = Dropout(self.dropout_rate)(x)
hx_mean = Dense(z_size, kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.initializers,
name="hx_mean")(x)
hx_log_var = Dense(z_size, kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.initializers,
name="hx_log_var")(x)
hx_z = Lambda(sampling, output_shape=(z_size,), name='hx_z')([hx_mean, hx_log_var])
encoder_hx = Model(inputs_x, [hx_mean, hx_log_var, hx_z], name='encoder_hx')
latent_inputs_x = Input(shape=(z_size,), name='latent')
x = latent_inputs_x
for i in range(en_ly_size - 1, 0, -1):
ns = self.hidden_layers[i - 1]
x = Dense(ns, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers)(x)
x = BatchNormalization(center=True, scale=False)(x)
x = Activation(Relu)(x)
x = Dropout(self.dropout_rate_big)(x)
outputs_x = Dense(self.input_size, kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.initializers, activation="softplus")(x)
decoder_x = Model(latent_inputs_x, outputs_x, name='decoder_x')
latent_inputs_batch = Input(shape=(z_size,), name='latent_domain')
Flip = GradReverse()
d = Flip(latent_inputs_batch)
d = Dense(16, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers)(d)
d = BatchNormalization(center=True, scale=False)(d)
d = Activation(Relu)(d)
d = Dropout(self.dropout_rate_big)(d)
d = Dense(self.batches, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers,
activation="softmax")(d)
domian_classifier = Model(latent_inputs_batch, d, name='domain_classifier')
outputs_x = decoder_x(encoder_hx(inputs_x)[2])
domain_pred = domian_classifier(encoder_hx(inputs_x)[2])
dann_vae = Model([inputs_x, inputs_batch, inputs_loss_weights], [outputs_x, domain_pred], name='vae_mlp')
inputs_x = tf.multiply(inputs_x, inputs_loss_weights)
outputs_x = tf.multiply(outputs_x, inputs_loss_weights)
reconstruction_loss = mse(inputs_x, outputs_x)
# reconstruction_loss = mse(inputs_x, outputs_x)
noise = tf.math.subtract(inputs_x, outputs_x)
var = tf.math.reduce_variance(noise)
reconstruction_loss *= (0.5*self.input_size)/var
reconstruction_loss += (0.5*self.input_size)/var*tf.math.log(var)
kl_loss_z = -0.5 * K.sum(1 + hx_log_var - K.square(hx_mean) - K.exp(hx_log_var), axis=-1)
pred_loss = K.categorical_crossentropy(inputs_batch, domain_pred)*self.input_size*self.domain_scale_factor
vae_loss = K.mean(reconstruction_loss + kl_loss_z + pred_loss)
dann_vae.add_loss(vae_loss)
self.dann_vae = dann_vae
self.encoder = encoder_hx
self.decoder = decoder_x
def compile(self):
self.dann_vae.compile(optimizer=self.optimizer)
self.dann_vae.summary()
def train(self, x, batch, loss_weights, batch_size=100, epochs=300):
history = self.dann_vae.fit({'inputs': x, 'inputs_batch': batch, 'inputs_weights': loss_weights},
epochs=epochs, batch_size=batch_size,
validation_split=self.validation_split, shuffle=True)
return history
def get_output(self, x, batches,):
[z_mean, z_log_var, z] = self.encoder.predict(x)
output_x = self.decoder.predict(z_mean)
return z_mean, output_x
class DACVAE:
def __init__(self, input_size, batches=2, domain_scale_factor=1.0, hidden_layers=[128, 64, 32, 5], path=''):
self.input_size = input_size
self.path = path
self.dann_vae = None
self.inputs = None
self.outputs_x = None
self.initializers = "glorot_uniform"
self.optimizer = optimizers.Adam(lr=0.01)
self.hidden_layers = hidden_layers
self.dropout_rate_small = 0.01
self.dropout_rate_big = 0.1
self.kernel_regularizer = regularizers.l1_l2(l1=0.00, l2=0.00)
self.domain_scale_factor = domain_scale_factor
self.validation_split = 0.0
self.batches = batches
self.dropout_rate = 0.01
callbacks = []
checkpointer = ModelCheckpoint(filepath=path + "vae_weights.h5", verbose=1, save_best_only=False,
save_weights_only=True)
reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.8, patience=100, min_lr=0.0001)
early_stop = EarlyStopping(monitor='loss', patience=200)
tensor_board = TensorBoard(log_dir=path + 'logs/')
callbacks.append(checkpointer)
callbacks.append(reduce_lr)
callbacks.append(early_stop)
callbacks.append(tensor_board)
self.callbacks = callbacks
def build(self):
Relu = "relu"
en_ly_size = len(self.hidden_layers)
z_size = self.hidden_layers[en_ly_size-1]
inputs_x = Input(shape=(self.input_size,), name='inputs')
inputs_batch = Input(shape=(self.batches,), name='inputs_batch')
inputs_loss_weights = Input(shape=(1,), name='inputs_weights')
inputs = concatenate([inputs_x, inputs_batch])
x = inputs
for i in range(en_ly_size):
if i == en_ly_size - 1:
break
ns = self.hidden_layers[i]
x = Dense(ns, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers)(x)
x = BatchNormalization(center=True, scale=False)(x)
x = Activation(Relu)(x)
x = Dropout(self.dropout_rate_big)(x)
hx_mean = Dense(z_size, kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.initializers,
name="hx_mean")(x)
hx_log_var = Dense(z_size, kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.initializers,
name="hx_log_var")(x)
hx_z = Lambda(sampling, output_shape=(z_size,), name='hx_z')([hx_mean, hx_log_var])
encoder_hx = Model([inputs_x, inputs_batch], [hx_mean, hx_log_var, hx_z], name='encoder_hx')
latent_inputs_x = Input(shape=(z_size,), name='latent')
latent_inputs_batch = Input(shape=(self.batches,), name='latent_batch')
latent_inputs = concatenate([latent_inputs_x, latent_inputs_batch])
x=latent_inputs
for i in range(en_ly_size - 1, 0, -1):
ns = self.hidden_layers[i - 1]
x = Dense(ns, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers)(x)
x = BatchNormalization(center=True, scale=False)(x)
x = Activation(Relu)(x)
x = Dropout(self.dropout_rate_big)(x)
outputs_x = Dense(self.input_size, kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.initializers, activation="softplus")(x)
decoder_x = Model([latent_inputs_x, latent_inputs_batch], outputs_x, name='decoder_x')
latent_inputs_domain = Input(shape=(z_size,), name='latent_domain')
Flip = GradReverse()
# d= grad_reverse(latent_inputs_batch)
d = Flip(latent_inputs_domain)
d = Dense(16, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers)(d)
d = BatchNormalization(center=True, scale=False)(d)
d = Activation(Relu)(d)
d = Dropout(self.dropout_rate_small)(d)
d = Dense(self.batches, kernel_regularizer=self.kernel_regularizer, kernel_initializer=self.initializers,
activation="softmax")(d)
domian_classifier = Model(latent_inputs_domain, d, name='domain_classifier')
outputs_x = decoder_x([encoder_hx([inputs_x, inputs_batch])[2], inputs_batch])
domain_pred = domian_classifier(encoder_hx([inputs_x, inputs_batch])[2])
dann_vae = Model([inputs_x, inputs_batch, inputs_loss_weights], [outputs_x, domain_pred], name='vae_mlp')
inputs_x = tf.multiply(inputs_x, inputs_loss_weights)
outputs_x = tf.multiply(outputs_x, inputs_loss_weights)
reconstruction_loss = mse(inputs_x, outputs_x)
# hx_log_var = tf.multiply(hx_log_var, inputs_loss_weights)
# hx_mean = tf.multiply(hx_mean, inputs_loss_weights)
# reconstruction_loss = mse(inputs_x, outputs_x)
noise = tf.math.subtract(inputs_x, outputs_x)
var = tf.math.reduce_variance(noise)
reconstruction_loss *= (0.5*self.input_size)/var
reconstruction_loss += (0.5*self.input_size)/var*tf.math.log(var)
kl_loss_z = -0.5 * K.sum(1 + hx_log_var - K.square(hx_mean) - K.exp(hx_log_var), axis=-1)
pred_loss = K.categorical_crossentropy(inputs_batch, domain_pred)*self.input_size*self.domain_scale_factor
vae_loss = K.mean(reconstruction_loss + kl_loss_z + pred_loss)
dann_vae.add_loss(vae_loss)
self.dann_vae = dann_vae
self.encoder = encoder_hx
self.decoder = decoder_x
def compile(self):
self.dann_vae.compile(optimizer=self.optimizer)
self.dann_vae.summary()
def train(self, x, batch, loss_weights, batch_size=100, epochs=300):
history = self.dann_vae.fit({'inputs': x, 'inputs_batch': batch, 'inputs_weights': loss_weights},
epochs=epochs, batch_size=batch_size,
validation_split=self.validation_split, shuffle=True)
return history
def get_output(self, x, batches,):
[z_mean, z_log_var, z] = self.encoder.predict([x, batches])
output_x = self.decoder.predict([z_mean, batches])
return z_mean, output_x
[docs]def fit_integration(adata, batch_num=2, mode='DACVAE', split_by='batch_label', epochs=20, batch_size=128,
domain_lambda=1.0, sparse=True, hidden_layers=[128,64,32,5]):
"""/
Build DAVAE model and fit the data to the model for training.
Parameters
----------
adata: AnnData
AnnData object need to be integrated.
batch_num: int, optional (default: 2)
Number of batches of datasets to be integrated.
mode: string, optional (default: 'DACVAE')
if 'DACVAE', construct a DACVAE model
if 'DAVAE', construct a DAVAE model
split_by: string, optional (default: '_batch')
the obsm_name of obsm used to distinguish different batches.
epochs: int, optional (default: 200)
Number of epochs to train the model. An epoch is an iteration over the entire x and y data provided.
batch_size: int or None, optional (default: 256)
Number of samples per gradient update. If unspecified, batch_size will default to 32.
domain_lambda: double, optional (default: 1.0)
The coefficient multiplied by the loss value of the domian classifier of DAVAE model.
sparse: bool, optional (default: True)
If True, Matrix X in the AnnData object is stored as a sparse matrix.
hidden_layers: list of integers, (default: [128,64,32,5])
Number of hidden layer neurons in the model.
Returns
-------
:class:`~anndata.AnnData`
out_adata
"""
batch = adata.obs[split_by]
batch = np.array(batch.values, dtype=int)
loss_weight = adata.obs['loss_weight']
# orig_batch, batch_num = tl.generate_batch_code(batch, batch_num)
orig_batch = to_categorical(batch)
if sparse:
orig_data = adata.X.A
else:
orig_data = adata.X
data, batch, loss_weight = shuffle(orig_data, orig_batch, loss_weight, random_state=0)
if mode=='DAVAE':
net = DAVAE(input_size=data.shape[1], batches=batch_num, domain_scale_factor=domain_lambda,
hidden_layers=hidden_layers)
else:
net = DACVAE(input_size=data.shape[1], batches=batch_num, domain_scale_factor=domain_lambda,
hidden_layers=hidden_layers)
net.build()
net.compile()
net.train(data, batch, loss_weight, batch_size=batch_size, epochs=epochs)
latent_z, output_x = net.get_output(orig_data, orig_batch)
out_adata = anndata.AnnData(X=output_x, obs=adata.obs, var=adata.var)
out_adata.obsm['X_davae'] = latent_z
out_adata.raw = adata.copy()
return out_adata