Integrating Visium data

This tutorial covers loading, preprocessing, DAVAE integration and visualization of spatial transcriptomics data. We first integrate the Anterior and Posterior mouse brain sections from 10x Genomics Visium, and then integrate the mouse brain Visium data with scRNA-seq data.

For a more detailed tutorial on Visium data, please refer to the scanpy tutorial.

Importing scbean package

Here, we’ll import scbean along with other popular packages.

[1]:
import scbean.model.davae as davae
from scbean.tools import utils as tl
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib
from numpy.random import seed
seed(2021)  # seed NumPy's global random number generator
# matplotlib.use('TkAgg')  # optional: interactive window when running as a plain script

# Command for Jupyter notebooks only
%matplotlib inline
# Silence warnings and matplotlib axes messages to keep the output readable
import warnings
warnings.filterwarnings('ignore')
from matplotlib.axes._axes import _log as matplotlib_axes_logger
matplotlib_axes_logger.setLevel('ERROR')

DAVAE integration of two spatial gene expression datasets

Loading data

[2]:
# Replace base_path with the directory where you stored the 10x Visium data
base_path = '/Users/zhongyuanke/data/'
file1_spatial = base_path+'spatial/mouse_brain/10x_mouse_brain_Anterior/'
file2_spatial = base_path+'spatial/mouse_brain/10x_mouse_brain_Posterior/'
file1 = base_path+'spatial/mouse_brain/10x_mouse_brain_Anterior/V1_Mouse_Brain_Sagittal_Anterior_filtered_feature_bc_matrix.h5'
file2 = base_path+'spatial/mouse_brain/10x_mouse_brain_Posterior/V1_Mouse_Brain_Sagittal_Posterior_filtered_feature_bc_matrix.h5'

adata_spatial_anterior = sc.read_visium(file1_spatial, count_file=file1)
adata_spatial_posterior = sc.read_visium(file2_spatial, count_file=file2)
adata_spatial_anterior.var_names_make_unique()
adata_spatial_posterior.var_names_make_unique()
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
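The warning above appears because some gene symbols occur more than once in the feature list; var_names_make_unique() renames the repeats in place. The sketch below is a hypothetical pure-Python version of that renaming (make_unique is our own name, not a scanpy or anndata function). It follows the convention anndata uses, where later copies of a name get "-1", "-2" and so on, but the real implementation may differ in edge cases.

```python
def make_unique(names, join="-"):
    """Hypothetical sketch: append -1, -2, ... to repeated names."""
    counts = {}
    taken = set(names)  # never generate a name that already exists in the input
    result = []
    for name in names:
        if name not in counts:
            # first occurrence keeps its original name
            counts[name] = 0
            result.append(name)
            continue
        # later occurrences get the next free numeric suffix
        while True:
            counts[name] += 1
            candidate = f"{name}{join}{counts[name]}"
            if candidate not in taken:
                break
        taken.add(candidate)
        result.append(candidate)
    return result


print(make_unique(["Actb", "Gm1", "Gm1", "Gm1"]))  # ['Actb', 'Gm1', 'Gm1-1', 'Gm1-2']
```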
[3]:
adata_spatial_anterior
[3]:
AnnData object with n_obs × n_vars = 2695 × 32285
    obs: 'in_tissue', 'array_row', 'array_col'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'

Data preprocessing

Here, we filter and normalize each dataset separately and concatenate them into one AnnData object. For more details, please check the preprocessing API documentation.

[4]:
adata_spatial = tl.spatial_preprocessing([adata_spatial_anterior, adata_spatial_posterior])
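tl.spatial_preprocessing hides the individual steps. Judging from the keys it adds to the combined object (a log1p entry and per-library highly-variable-gene statistics), it likely follows the usual scanpy recipe. The NumPy sketch below shows only the library-size normalization and log transform applied to each dataset on its own, followed by stacking the rows with a library label. It is an assumption, not scbean's code: the function names and target_sum=1e4 are illustrative.

```python
import numpy as np


def normalize_log1p(counts, target_sum=1e4):
    """Scale each spot to target_sum total counts, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=float)
    library_size = counts.sum(axis=1, keepdims=True)
    library_size[library_size == 0] = 1.0  # leave empty spots at zero
    return np.log1p(counts / library_size * target_sum)


def preprocess_and_stack(batches, library_ids, target_sum=1e4):
    """Normalize every batch separately, then stack rows and record their origin."""
    normalized = [normalize_log1p(b, target_sum) for b in batches]
    labels = np.concatenate(
        [np.repeat(lib, m.shape[0]) for lib, m in zip(library_ids, normalized)]
    )
    return np.vstack(normalized), labels


anterior = np.array([[1, 1, 0], [0, 2, 2]])
posterior = np.array([[4, 0, 0]])
X, library_id = preprocess_and_stack([anterior, posterior], ["Anterior", "Posterior"])
print(X.shape, library_id)
```

Normalizing each section before concatenation keeps one section's sequencing depth from influencing the other's scaling.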

DAVAE integration

The following code integrates the two sections with DAVAE:

[5]:
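adata_integrate = davae.fit_integration(
    adata_spatial,
    epochs=25,
    split_by='loss_weight',  # obs column created by tl.spatial_preprocessing
    hidden_layers=[128, 64, 32, 5],  # encoder sizes; the last value is the latent dimension
    sparse=True,
    domain_lambda=0.5,  # weight of the domain-classifier loss term
)
# Copy the DAVAE embedding back onto the concatenated spatial object
adata_spatial.obsm["X_davae"] = adata_integrate.obsm['X_davae']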
Model: "vae_mlp"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
inputs (InputLayer)             [(None, 32285)]      0
__________________________________________________________________________________________________
inputs_batch (InputLayer)       [(None, 2)]          0
__________________________________________________________________________________________________
encoder_hx (Functional)         [(None, 5), (None, 5 4144202     inputs[0][0]
                                                                 inputs_batch[0][0]
                                                                 inputs[0][0]
                                                                 inputs_batch[0][0]
__________________________________________________________________________________________________
inputs_weights (InputLayer)     [(None, 1)]          0
__________________________________________________________________________________________________
decoder_x (Functional)          (None, 32285)        4176125     encoder_hx[0][2]
                                                                 inputs_batch[0][0]
__________________________________________________________________________________________________
domain_classifier (Functional)  (None, 2)            178         encoder_hx[1][2]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 32287)        0           inputs[0][0]
                                                                 inputs_batch[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          4132864     concatenate[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128)          384         dense[0][0]
__________________________________________________________________________________________________
activation (Activation)         (None, 128)          0           batch_normalization[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 128)          0           activation[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 64)           8256        dropout[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64)           192         dense_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 64)           0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 64)           0           activation_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 32)           2080        dropout_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 32)           96          dense_2[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 32)           0           batch_normalization_2[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 32)           0           activation_2[0][0]
__________________________________________________________________________________________________
tf.math.subtract (TFOpLambda)   (None, 32285)        0           inputs[0][0]
                                                                 decoder_x[0][0]
__________________________________________________________________________________________________
hx_log_var (Dense)              (None, 5)            165         dropout_2[0][0]
__________________________________________________________________________________________________
hx_mean (Dense)                 (None, 5)            165         dropout_2[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean_1 (TFOpLamb (1, 1)               0           tf.math.subtract[0][0]
__________________________________________________________________________________________________
tf.math.multiply_2 (TFOpLambda) (None, 5)            0           hx_log_var[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.multiply_3 (TFOpLambda) (None, 5)            0           hx_mean[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.multiply_1 (TFOpLambda) (None, 32285)        0           decoder_x[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.multiply (TFOpLambda)   (None, 32285)        0           inputs[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.subtract_1 (TFOpLambda) (None, 32285)        0           tf.math.subtract[0][0]
                                                                 tf.math.reduce_mean_1[0][0]
__________________________________________________________________________________________________
tf.__operators__.add_1 (TFOpLam (None, 5)            0           tf.math.multiply_2[0][0]
__________________________________________________________________________________________________
tf.math.square_1 (TFOpLambda)   (None, 5)            0           tf.math.multiply_3[0][0]
__________________________________________________________________________________________________
tf.convert_to_tensor (TFOpLambd (None, 32285)        0           tf.math.multiply_1[0][0]
__________________________________________________________________________________________________
tf.cast (TFOpLambda)            (None, 32285)        0           tf.math.multiply[0][0]
__________________________________________________________________________________________________
tf.math.square (TFOpLambda)     (None, 32285)        0           tf.math.subtract_1[0][0]
__________________________________________________________________________________________________
tf.math.subtract_2 (TFOpLambda) (None, 5)            0           tf.__operators__.add_1[0][0]
                                                                 tf.math.square_1[0][0]
__________________________________________________________________________________________________
tf.math.exp (TFOpLambda)        (None, 5)            0           tf.math.multiply_2[0][0]
__________________________________________________________________________________________________
tf.math.squared_difference (TFO (None, 32285)        0           tf.convert_to_tensor[0][0]
                                                                 tf.cast[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean_2 (TFOpLamb ()                   0           tf.math.square[0][0]
__________________________________________________________________________________________________
tf.math.subtract_3 (TFOpLambda) (None, 5)            0           tf.math.subtract_2[0][0]
                                                                 tf.math.exp[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean (TFOpLambda (None,)              0           tf.math.squared_difference[0][0]
__________________________________________________________________________________________________
tf.math.truediv (TFOpLambda)    ()                   0           tf.math.reduce_mean_2[0][0]
__________________________________________________________________________________________________
tf.math.truediv_1 (TFOpLambda)  ()                   0           tf.math.reduce_mean_2[0][0]
__________________________________________________________________________________________________
tf.math.log (TFOpLambda)        ()                   0           tf.math.reduce_mean_2[0][0]
__________________________________________________________________________________________________
tf.math.reduce_sum (TFOpLambda) (None,)              0           tf.math.subtract_3[0][0]
__________________________________________________________________________________________________
tf.math.multiply_4 (TFOpLambda) (None,)              0           tf.math.reduce_mean[0][0]
                                                                 tf.math.truediv[0][0]
__________________________________________________________________________________________________
tf.math.multiply_5 (TFOpLambda) ()                   0           tf.math.truediv_1[0][0]
                                                                 tf.math.log[0][0]
__________________________________________________________________________________________________
tf.math.multiply_6 (TFOpLambda) (None,)              0           tf.math.reduce_sum[0][0]
__________________________________________________________________________________________________
tf.keras.backend.categorical_cr (None,)              0           inputs_batch[0][0]
                                                                 domain_classifier[0][0]
__________________________________________________________________________________________________
tf.__operators__.add (TFOpLambd (None,)              0           tf.math.multiply_4[0][0]
                                                                 tf.math.multiply_5[0][0]
__________________________________________________________________________________________________
tf.math.multiply_9 (TFOpLambda) (None,)              0           tf.math.multiply_6[0][0]
__________________________________________________________________________________________________
tf.math.multiply_7 (TFOpLambda) (None,)              0           tf.keras.backend.categorical_cros
__________________________________________________________________________________________________
tf.__operators__.add_2 (TFOpLam (None,)              0           tf.__operators__.add[0][0]
                                                                 tf.math.multiply_9[0][0]
__________________________________________________________________________________________________
tf.math.multiply_8 (TFOpLambda) (None,)              0           tf.math.multiply_7[0][0]
__________________________________________________________________________________________________
tf.__operators__.add_3 (TFOpLam (None,)              0           tf.__operators__.add_2[0][0]
                                                                 tf.math.multiply_8[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean_3 (TFOpLamb ()                   0           tf.__operators__.add_3[0][0]
__________________________________________________________________________________________________
add_loss (AddLoss)              ()                   0           tf.math.reduce_mean_3[0][0]
==================================================================================================
Total params: 8,320,505
Trainable params: 8,319,577
Non-trainable params: 928
__________________________________________________________________________________________________
Epoch 1/25
48/48 [==============================] - 9s 128ms/step - loss: -389110.5081
Epoch 2/25
48/48 [==============================] - 6s 128ms/step - loss: -561719.2908
Epoch 3/25
48/48 [==============================] - 6s 128ms/step - loss: -572639.8929
Epoch 4/25
48/48 [==============================] - 6s 128ms/step - loss: -584418.4860
Epoch 5/25
48/48 [==============================] - 6s 128ms/step - loss: -589162.1071
Epoch 6/25
48/48 [==============================] - 6s 128ms/step - loss: -582521.8865
Epoch 7/25
48/48 [==============================] - 6s 128ms/step - loss: -590242.2857
Epoch 8/25
48/48 [==============================] - 6s 128ms/step - loss: -588344.7066
Epoch 9/25
48/48 [==============================] - 6s 128ms/step - loss: -596586.3890
Epoch 10/25
48/48 [==============================] - 6s 127ms/step - loss: -598210.0217
Epoch 11/25
48/48 [==============================] - 6s 128ms/step - loss: -590263.4260
Epoch 12/25
48/48 [==============================] - 6s 127ms/step - loss: -595876.3814
Epoch 13/25
48/48 [==============================] - 6s 127ms/step - loss: -601443.7500
Epoch 14/25
48/48 [==============================] - 6s 125ms/step - loss: -603635.2360
Epoch 15/25
48/48 [==============================] - 6s 127ms/step - loss: -602548.6416
Epoch 16/25
48/48 [==============================] - 6s 126ms/step - loss: -606242.9719
Epoch 17/25
48/48 [==============================] - 6s 127ms/step - loss: -601354.4311
Epoch 18/25
48/48 [==============================] - 6s 127ms/step - loss: -608207.0102
Epoch 19/25
48/48 [==============================] - 6s 127ms/step - loss: -605368.6684
Epoch 20/25
48/48 [==============================] - 6s 127ms/step - loss: -603970.5918
Epoch 21/25
48/48 [==============================] - 6s 127ms/step - loss: -613009.8278
Epoch 22/25
48/48 [==============================] - 6s 127ms/step - loss: -608529.0459
Epoch 23/25
48/48 [==============================] - 6s 128ms/step - loss: -606853.6327
Epoch 24/25
48/48 [==============================] - 6s 127ms/step - loss: -611104.3750
Epoch 25/25
48/48 [==============================] - 6s 127ms/step - loss: -623454.4184
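The model summary shows which terms DAVAE combines during training: a reconstruction error between inputs and decoder_x, a KL-style term built from hx_mean and hx_log_var, and a categorical cross-entropy from domain_classifier. The NumPy sketch below computes textbook versions of those three terms and adds them, with domain_lambda weighting the domain term. It is an illustration under that assumption, not scbean's exact loss: the summary also shows per-spot weights and extra scaling that are omitted here. In domain-adversarial training the domain term usually reaches the encoder through a gradient reversal layer, which changes the gradients but not the forward value computed here.

```python
import numpy as np


def reconstruction_mse(x, x_hat):
    """Mean squared error per spot, averaged over genes."""
    return np.mean((x - x_hat) ** 2, axis=1)


def gaussian_kl(mean, log_var):
    """KL(N(mean, exp(log_var)) || N(0, I)) per spot, summed over latent dims."""
    return -0.5 * np.sum(1.0 + log_var - mean**2 - np.exp(log_var), axis=1)


def categorical_cross_entropy(onehot, probs, eps=1e-7):
    """Cross-entropy between one-hot batch labels and predicted batch probabilities."""
    probs = np.clip(probs, eps, 1.0 - eps)
    return -np.sum(onehot * np.log(probs), axis=1)


def vae_domain_loss(x, x_hat, mean, log_var, batch_onehot, batch_probs, domain_lambda=0.5):
    """Illustrative total: reconstruction + KL + domain_lambda * domain cross-entropy."""
    per_spot = (
        reconstruction_mse(x, x_hat)
        + gaussian_kl(mean, log_var)
        + domain_lambda * categorical_cross_entropy(batch_onehot, batch_probs)
    )
    return float(np.mean(per_spot))


rng = np.random.default_rng(2021)
x = rng.random((4, 6))
loss = vae_domain_loss(
    x, x, np.zeros((4, 5)), np.zeros((4, 5)),
    batch_onehot=np.eye(2)[[0, 0, 1, 1]],
    batch_probs=np.full((4, 2), 0.5),
)
print(loss)  # only the domain term is non-zero here: 0.5 * ln 2
```

A domain classifier that outputs 0.5 for both sections, as in the usage line, is the state the adversarial term pushes toward: the latent space no longer reveals which section a spot came from.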

The DAVAE embedding of each spot has been saved in adata_spatial.obsm['X_davae'].

[6]:
adata_spatial
[6]:
AnnData object with n_obs × n_vars = 6050 × 32285
    obs: 'in_tissue', 'array_row', 'array_col', 'loss_weight', 'library_id'
    var: 'gene_ids', 'feature_types', 'genome', 'highly_variable-V1_Mouse_Brain_Sagittal_Anterior', 'means-V1_Mouse_Brain_Sagittal_Anterior', 'dispersions-V1_Mouse_Brain_Sagittal_Anterior', 'dispersions_norm-V1_Mouse_Brain_Sagittal_Anterior', 'highly_variable-V1_Mouse_Brain_Sagittal_Posterior', 'means-V1_Mouse_Brain_Sagittal_Posterior', 'dispersions-V1_Mouse_Brain_Sagittal_Posterior', 'dispersions_norm-V1_Mouse_Brain_Sagittal_Posterior'
    uns: 'spatial', 'log1p', 'hvg'
    obsm: 'spatial', 'X_davae'

UMAP visualization and clustering

We build a neighborhood graph on the DAVAE embedding, project it to two dimensions with UMAP for visualization, and cluster the spots with the Louvain algorithm.

[7]:
sc.set_figure_params(facecolor="white", figsize=(5, 4))
# Build the kNN graph on the DAVAE embedding instead of PCA
sc.pp.neighbors(adata_spatial, use_rep='X_davae', n_neighbors=12)
sc.tl.umap(adata_spatial)
# Cluster on the same graph; labels are stored in obs['clusters']
sc.tl.louvain(adata_spatial, key_added="clusters")
sc.pl.umap(adata_spatial, color=['library_id', "clusters"],
           size=8, color_map='Set2', frameon=False)
... storing 'feature_types' as categorical
... storing 'genome' as categorical
[Figure: UMAP of the DAVAE embedding, colored by library_id and clusters]
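With use_rep='X_davae', sc.pp.neighbors builds the graph that both UMAP and Louvain work on from the DAVAE embedding rather than from PCA. The exact-search NumPy sketch below shows the core idea of finding each spot's nearest neighbors in that embedding. knn_indices is a hypothetical helper; scanpy itself uses approximate search and turns distances into weighted UMAP connectivities, so its graph is not identical.

```python
import numpy as np


def knn_indices(embedding, n_neighbors=12):
    """Indices of the n_neighbors closest spots (Euclidean), excluding each spot itself."""
    sq_norms = np.sum(embedding**2, axis=1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b for every pair of spots
    dist2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * embedding @ embedding.T
    np.fill_diagonal(dist2, np.inf)  # a spot is not its own neighbor
    return np.argsort(dist2, axis=1)[:, :n_neighbors]


# Two well-separated groups in a 5-dimensional "embedding"
rng = np.random.default_rng(0)
emb = np.vstack([rng.normal(0, 0.1, (20, 5)), rng.normal(5, 0.1, (20, 5))])
nbrs = knn_indices(emb, n_neighbors=12)
print(nbrs.shape)  # (40, 12)
```

Because every spot's neighbors fall in its own group, a graph clustering such as Louvain would separate the two groups; this is why a well-mixed embedding matters before clustering.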

Visualization in spatial coordinates

For more on visualizing Visium data, please refer to the documentation of the sc.pl.spatial() function.

[8]:
# Map every cluster label to the color it received in the UMAP plot
clusters_colors = dict(
    zip(adata_spatial.obs["clusters"].cat.categories, adata_spatial.uns["clusters_colors"])
)
fig, axs = plt.subplots(1, 2, figsize=(10, 6))

for i, library in enumerate(
    ["V1_Mouse_Brain_Sagittal_Anterior", "V1_Mouse_Brain_Sagittal_Posterior"]
):
    # Plot each section on its own tissue image
    ad = adata_spatial[adata_spatial.obs.library_id == library, :].copy()
    sc.pl.spatial(
        ad,
        img_key="hires",
        library_id=library,
        color="clusters",
        size=1.5,
        # Reuse the global colors for the clusters present in this section
        palette=[clusters_colors[c] for c in ad.obs["clusters"].cat.categories],
        legend_loc=None,
        show=False,
        ax=axs[i],
    )

plt.tight_layout()
plt.show()
[Figure: Anterior and Posterior sections in spatial coordinates, colored by clusters]
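When each section is plotted on its own, it may contain only some of the clusters. For a cluster to keep the same color in both panels and in the UMAP, each section's palette should hold the colors of exactly the clusters present in it, in category order, rather than assuming a fixed number of clusters. The hypothetical helper below shows that lookup in plain Python.

```python
def section_palette(all_categories, all_colors, present_in_section):
    """Colors for the clusters present in one section, in global category order."""
    color_of = dict(zip(all_categories, all_colors))
    present = set(present_in_section)
    return [color_of[c] for c in all_categories if c in present]


categories = ["0", "1", "2", "3"]
colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]
print(section_palette(categories, colors, ["3", "1", "3"]))  # ['#ff7f0e', '#d62728']
```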

DAVAE integration of spatial gene expression and scRNA-seq data

Following the scanpy tutorial, we also integrate a scRNA-seq dataset with a spatial transcriptomics dataset. This allows us to transfer the cell type labels identified in the scRNA-seq dataset to the Visium dataset.

Importing additional packages

Here, we’ll import the additional packages needed for this part, beyond those imported at the beginning.

[9]:
import pandas as pd
from sklearn.metrics.pairwise import cosine_distances
import numpy as np
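Once spots and cells share one embedding, a simple way to carry cell type labels over to the spots is to let each spot take the majority label of its nearest scRNA-seq cells under cosine distance, the quantity sklearn.metrics.pairwise.cosine_distances computes (1 minus cosine similarity). The NumPy sketch below is one such k-nearest-neighbor vote. transfer_labels, k and the toy data are assumptions for illustration, not necessarily how this tutorial transfers labels.

```python
import numpy as np
from collections import Counter


def cosine_distance_matrix(a, b):
    """1 - cosine similarity between every row of a and every row of b."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return 1.0 - a @ b.T


def transfer_labels(spot_emb, cell_emb, cell_labels, k=5):
    """Give each spot the most common label among its k nearest cells."""
    dist = cosine_distance_matrix(spot_emb, cell_emb)
    nearest = np.argsort(dist, axis=1)[:, :k]
    return [Counter(cell_labels[j] for j in row).most_common(1)[0][0] for row in nearest]


# Toy 2-D "embedding": three cells of each hypothetical type
cells = np.array([[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.0, 1.0], [0.1, 0.9], [0.2, 0.8]])
labels = ["Astro", "Astro", "Astro", "Oligo", "Oligo", "Oligo"]
spots = np.array([[1.0, 0.1], [0.1, 1.0]])
print(transfer_labels(spots, cells, labels, k=3))  # ['Astro', 'Oligo']
```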

Loading data

The scRNA-seq dataset can be downloaded from GEO. Alternatively, you can download the preprocessed dataset in h5ad format from here.

Load scRNA-seq dataset

[10]:
base_path = '/Users/zhongyuanke/data/'
file_rna = base_path+'spatial/mouse_brain/adata_processed_sc.h5ad'
adata_rna = sc.read_h5ad(file_rna)

Load Visium dataset

[11]:
file1 = base_path+'spatial/mouse_brain/10x_mouse_brain_Anterior/V1_Mouse_Brain_Sagittal_Anterior_filtered_feature_bc_matrix.h5'
file1_spatial = base_path+'spatial/mouse_brain/10x_mouse_brain_Anterior/'
adata_spatial_anterior = sc.read_visium(file1_spatial, count_file=file1)
adata_spatial_anterior.var_names_make_unique()
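# Keep only spots whose second spatial (pixel) coordinate is below 6000,
# i.e. a sub-region of the anterior section
adata_spatial_anterior = adata_spatial_anterior[
    adata_spatial_anterior.obsm["spatial"][:, 1] < 6000, :
]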
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.

Preprocessing

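This step preprocesses the Visium and scRNA-seq datasets and combines them into a single AnnData object. For more detailed usage, please refer to the API documentation.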

[12]:
adata_all = tl.spatial_rna_preprocessing(
    adata_spatial_anterior,
    adata_rna,
)
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
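Before integration, both datasets have to be described by the same genes in the same column order; the model summary of the next cell shows 3344 input features for the combined object, far fewer than the 32285 features of a Visium section. The sketch below shows a hypothetical alignment step that keeps only the genes found in both datasets. align_to_shared_genes is our own helper, not tl.spatial_rna_preprocessing, which may also select highly variable genes.

```python
import numpy as np


def align_to_shared_genes(x_a, genes_a, x_b, genes_b):
    """Restrict two cell-by-gene matrices to their shared genes, in genes_a order."""
    col_b = {g: i for i, g in enumerate(genes_b)}
    shared = [g for g in genes_a if g in col_b]
    col_a = {g: i for i, g in enumerate(genes_a)}
    x_a_shared = x_a[:, [col_a[g] for g in shared]]
    x_b_shared = x_b[:, [col_b[g] for g in shared]]
    return x_a_shared, x_b_shared, shared


spatial = np.array([[1, 2, 3], [4, 5, 6]])  # genes: Gad1, Snap25, Mbp
rna = np.array([[7, 8, 9]])                 # genes: Mbp, Gad1, Olig2
xs, xr, genes = align_to_shared_genes(
    spatial, ["Gad1", "Snap25", "Mbp"], rna, ["Mbp", "Gad1", "Olig2"]
)
print(genes)  # ['Gad1', 'Mbp']
```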

DAVAE integration

[13]:
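adata_integrate = davae.fit_integration(
    adata_all,
    epochs=40,
    batch_size=128,
    domain_lambda=2.5,  # heavier weight on the domain term than in the first example (0.5)
    sparse=True,
    hidden_layers=[128, 64, 32, 10]  # the last value sets a 10-dimensional latent space
)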
Model: "vae_mlp"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
inputs (InputLayer)             [(None, 3344)]       0
__________________________________________________________________________________________________
inputs_batch (InputLayer)       [(None, 2)]          0
__________________________________________________________________________________________________
encoder_hx (Functional)         [(None, 10), (None,  440084      inputs[0][0]
                                                                 inputs_batch[0][0]
                                                                 inputs[0][0]
                                                                 inputs_batch[0][0]
__________________________________________________________________________________________________
inputs_weights (InputLayer)     [(None, 1)]          0
__________________________________________________________________________________________________
decoder_x (Functional)          (None, 3344)         442896      encoder_hx[0][2]
                                                                 inputs_batch[0][0]
__________________________________________________________________________________________________
domain_classifier (Functional)  (None, 2)            258         encoder_hx[1][2]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 3346)         0           inputs[0][0]
                                                                 inputs_batch[0][0]
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 128)          428416      concatenate_2[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 128)          384         dense_9[0][0]
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 128)          0           batch_normalization_7[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 128)          0           activation_7[0][0]
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 64)           8256        dropout_7[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 64)           192         dense_10[0][0]
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 64)           0           batch_normalization_8[0][0]
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 64)           0           activation_8[0][0]
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 32)           2080        dropout_8[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 32)           96          dense_11[0][0]
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 32)           0           batch_normalization_9[0][0]
__________________________________________________________________________________________________
dropout_9 (Dropout)             (None, 32)           0           activation_9[0][0]
__________________________________________________________________________________________________
tf.math.subtract_4 (TFOpLambda) (None, 3344)         0           inputs[0][0]
                                                                 decoder_x[0][0]
__________________________________________________________________________________________________
hx_log_var (Dense)              (None, 10)           330         dropout_9[0][0]
__________________________________________________________________________________________________
hx_mean (Dense)                 (None, 10)           330         dropout_9[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean_5 (TFOpLamb (1, 1)               0           tf.math.subtract_4[0][0]
__________________________________________________________________________________________________
tf.math.multiply_12 (TFOpLambda (None, 10)           0           hx_log_var[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.multiply_13 (TFOpLambda (None, 10)           0           hx_mean[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.multiply_11 (TFOpLambda (None, 3344)         0           decoder_x[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.multiply_10 (TFOpLambda (None, 3344)         0           inputs[0][0]
                                                                 inputs_weights[0][0]
__________________________________________________________________________________________________
tf.math.subtract_5 (TFOpLambda) (None, 3344)         0           tf.math.subtract_4[0][0]
                                                                 tf.math.reduce_mean_5[0][0]
__________________________________________________________________________________________________
tf.__operators__.add_5 (TFOpLam (None, 10)           0           tf.math.multiply_12[0][0]
__________________________________________________________________________________________________
tf.math.square_3 (TFOpLambda)   (None, 10)           0           tf.math.multiply_13[0][0]
__________________________________________________________________________________________________
tf.convert_to_tensor_3 (TFOpLam (None, 3344)         0           tf.math.multiply_11[0][0]
__________________________________________________________________________________________________
tf.cast_1 (TFOpLambda)          (None, 3344)         0           tf.math.multiply_10[0][0]
__________________________________________________________________________________________________
tf.math.square_2 (TFOpLambda)   (None, 3344)         0           tf.math.subtract_5[0][0]
__________________________________________________________________________________________________
tf.math.subtract_6 (TFOpLambda) (None, 10)           0           tf.__operators__.add_5[0][0]
                                                                 tf.math.square_3[0][0]
__________________________________________________________________________________________________
tf.math.exp_1 (TFOpLambda)      (None, 10)           0           tf.math.multiply_12[0][0]
__________________________________________________________________________________________________
tf.math.squared_difference_1 (T (None, 3344)         0           tf.convert_to_tensor_3[0][0]
                                                                 tf.cast_1[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean_6 (TFOpLamb ()                   0           tf.math.square_2[0][0]
__________________________________________________________________________________________________
tf.math.subtract_7 (TFOpLambda) (None, 10)           0           tf.math.subtract_6[0][0]
                                                                 tf.math.exp_1[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean_4 (TFOpLamb (None,)              0           tf.math.squared_difference_1[0][0
__________________________________________________________________________________________________
tf.math.truediv_2 (TFOpLambda)  ()                   0           tf.math.reduce_mean_6[0][0]
__________________________________________________________________________________________________
tf.math.truediv_3 (TFOpLambda)  ()                   0           tf.math.reduce_mean_6[0][0]
__________________________________________________________________________________________________
tf.math.log_1 (TFOpLambda)      ()                   0           tf.math.reduce_mean_6[0][0]
__________________________________________________________________________________________________
tf.math.reduce_sum_1 (TFOpLambd (None,)              0           tf.math.subtract_7[0][0]
__________________________________________________________________________________________________
tf.math.multiply_14 (TFOpLambda (None,)              0           tf.math.reduce_mean_4[0][0]
                                                                 tf.math.truediv_2[0][0]
__________________________________________________________________________________________________
tf.math.multiply_15 (TFOpLambda ()                   0           tf.math.truediv_3[0][0]
                                                                 tf.math.log_1[0][0]
__________________________________________________________________________________________________
tf.math.multiply_16 (TFOpLambda (None,)              0           tf.math.reduce_sum_1[0][0]
__________________________________________________________________________________________________
tf.keras.backend.categorical_cr (None,)              0           inputs_batch[0][0]
                                                                 domain_classifier[0][0]
__________________________________________________________________________________________________
tf.__operators__.add_4 (TFOpLam (None,)              0           tf.math.multiply_14[0][0]
                                                                 tf.math.multiply_15[0][0]
__________________________________________________________________________________________________
tf.math.multiply_19 (TFOpLambda (None,)              0           tf.math.multiply_16[0][0]
__________________________________________________________________________________________________
tf.math.multiply_17 (TFOpLambda (None,)              0           tf.keras.backend.categorical_cros
__________________________________________________________________________________________________
tf.__operators__.add_6 (TFOpLam (None,)              0           tf.__operators__.add_4[0][0]
                                                                 tf.math.multiply_19[0][0]
__________________________________________________________________________________________________
tf.math.multiply_18 (TFOpLambda (None,)              0           tf.math.multiply_17[0][0]
__________________________________________________________________________________________________
tf.__operators__.add_7 (TFOpLam (None,)              0           tf.__operators__.add_6[0][0]
                                                                 tf.math.multiply_18[0][0]
__________________________________________________________________________________________________
tf.math.reduce_mean_7 (TFOpLamb ()                   0           tf.__operators__.add_7[0][0]
__________________________________________________________________________________________________
add_loss_1 (AddLoss)            ()                   0           tf.math.reduce_mean_7[0][0]
==================================================================================================
Total params: 883,238
Trainable params: 882,310
Non-trainable params: 928
__________________________________________________________________________________________________
Epoch 1/40
182/182 [==============================] - 5s 14ms/step - loss: -10324.3097
Epoch 2/40
182/182 [==============================] - 3s 15ms/step - loss: -16200.0644
Epoch 3/40
182/182 [==============================] - 3s 14ms/step - loss: -18638.1599
Epoch 4/40
182/182 [==============================] - 3s 14ms/step - loss: -19375.9361
Epoch 5/40
182/182 [==============================] - 3s 14ms/step - loss: -20050.2633
Epoch 6/40
182/182 [==============================] - 3s 14ms/step - loss: -20510.9002
Epoch 7/40
182/182 [==============================] - 3s 14ms/step - loss: -20895.7268
Epoch 8/40
182/182 [==============================] - 3s 14ms/step - loss: -21202.0569
Epoch 9/40
182/182 [==============================] - 3s 14ms/step - loss: -21417.6603
Epoch 10/40
182/182 [==============================] - 3s 14ms/step - loss: -21656.4537
Epoch 11/40
182/182 [==============================] - 3s 14ms/step - loss: -21741.0907
Epoch 12/40
182/182 [==============================] - 3s 14ms/step - loss: -21918.3116
Epoch 13/40
182/182 [==============================] - 3s 14ms/step - loss: -21963.5711
Epoch 14/40
182/182 [==============================] - 3s 14ms/step - loss: -22076.1964
Epoch 15/40
182/182 [==============================] - 3s 14ms/step - loss: -22277.3562
Epoch 16/40
182/182 [==============================] - 3s 14ms/step - loss: -22145.0867
Epoch 17/40
182/182 [==============================] - 3s 14ms/step - loss: -22529.5795
Epoch 18/40
182/182 [==============================] - 3s 14ms/step - loss: -22361.3389
Epoch 19/40
182/182 [==============================] - 3s 14ms/step - loss: -22503.2769
Epoch 20/40
182/182 [==============================] - 3s 14ms/step - loss: -22559.2437
Epoch 21/40
182/182 [==============================] - 3s 14ms/step - loss: -22448.5279
Epoch 22/40
182/182 [==============================] - 3s 14ms/step - loss: -22959.0903
Epoch 23/40
182/182 [==============================] - 3s 14ms/step - loss: -22887.3262
Epoch 24/40
182/182 [==============================] - 3s 14ms/step - loss: -22861.7095
Epoch 25/40
182/182 [==============================] - 3s 14ms/step - loss: -22851.5816
Epoch 26/40
182/182 [==============================] - 3s 14ms/step - loss: -22857.8835
Epoch 27/40
182/182 [==============================] - 3s 14ms/step - loss: -22866.5289
Epoch 28/40
182/182 [==============================] - 3s 14ms/step - loss: -22814.3243
Epoch 29/40
182/182 [==============================] - 3s 14ms/step - loss: -22986.6867
Epoch 30/40
182/182 [==============================] - 3s 14ms/step - loss: -22757.1842
Epoch 31/40
182/182 [==============================] - 3s 14ms/step - loss: -22919.4136
Epoch 32/40
182/182 [==============================] - 3s 14ms/step - loss: -22915.2592
Epoch 33/40
182/182 [==============================] - 3s 14ms/step - loss: -23245.9947
Epoch 34/40
182/182 [==============================] - 3s 14ms/step - loss: -23003.6393
Epoch 35/40
182/182 [==============================] - 3s 14ms/step - loss: -23249.5969
Epoch 36/40
182/182 [==============================] - 3s 14ms/step - loss: -23100.3594
Epoch 37/40
182/182 [==============================] - 3s 14ms/step - loss: -23001.7339
Epoch 38/40
182/182 [==============================] - 3s 14ms/step - loss: -23129.2476
Epoch 39/40
182/182 [==============================] - 3s 14ms/step - loss: -23068.8846
Epoch 40/40
182/182 [==============================] - 3s 14ms/step - loss: -23374.5241

Calculate cosine similarity between embeddings

[14]:
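import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_distances

len_anterior = adata_spatial_anterior.shape[0]
len_rna = adata_rna.shape[0]
davae_emb = adata_integrate.obsm['X_davae']

# The integrated embedding stacks the anterior spots first, then the scRNA-seq cells
adata_spatial_anterior.obsm["davae_embedding"] = davae_emb[0:len_anterior, :]
adata_rna.obsm['davae_embedding'] = davae_emb[len_anterior:len_rna+len_anterior, :]

# 1 - cosine distance = cosine similarity, shape (n_rna_cells, n_spots)
distances_anterior = 1 - cosine_distances(
    adata_rna.obsm["davae_embedding"],
    adata_spatial_anterior.obsm['davae_embedding'],
)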
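Despite its name, `distances_anterior` holds cosine similarities, because scikit-learn defines cosine distance as 1 minus cosine similarity. The sketch below, using hypothetical toy data and a hypothetical helper `cosine_similarity_matrix`, shows the same two steps with plain NumPy: splitting a stacked joint embedding by row count (assuming, as in the cell above, that spots come first) and computing the cell-by-spot similarity matrix from unit-normalized rows.

```python
import numpy as np

def cosine_similarity_matrix(a, b):
    # Hypothetical helper: row-normalize both inputs to unit L2 norm
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    # Dot products of unit vectors are cosine similarities: shape (n_a, n_b)
    return a_unit @ b_unit.T

# Hypothetical joint embedding: the first 2 rows stand in for spatial spots,
# the remaining 3 rows stand in for scRNA-seq cells
emb = np.array([[1.0, 0.0],
                [0.0, 1.0],
                [2.0, 0.0],
                [0.0, 3.0],
                [1.0, 1.0]])
n_spots = 2
spots, cells = emb[:n_spots], emb[n_spots:]

# Rows are cells, columns are spots, matching the orientation used above
sim = cosine_similarity_matrix(cells, spots)
print(sim.shape)  # (3, 2)
```

Cosine similarity ignores vector length, so the cell `[2, 0]` is treated as identical to the spot `[1, 0]`.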

Transfer labels

[15]:
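def label_transfer(dist, labels):
    # One-hot encode the reference labels: (n_classes, n_rna_cells)
    lab = pd.get_dummies(labels).to_numpy(dtype=float).T
    # Sum the similarities of each class's cells to every spot: (n_classes, n_spots)
    class_prob = lab @ dist
    # L2-normalize the class scores of each spot
    norm = np.linalg.norm(class_prob, 2, axis=0)
    class_prob = class_prob / norm
    # Min-max scale each class across spots; result is (n_spots, n_classes)
    # np.ptp(x, axis=1) is the per-class range (max - min)
    class_prob = (class_prob.T - class_prob.min(1)) / np.ptp(class_prob, axis=1)
    return class_prob

class_prob_anterior = label_transfer(distances_anterior, adata_rna.obs.cell_subclass)
cp_anterior_df = pd.DataFrame(
    class_prob_anterior,
    # Same column order as the one-hot encoding inside label_transfer
    columns=pd.get_dummies(adata_rna.obs.cell_subclass).columns,
    index=adata_spatial_anterior.obs.index,
)
adata_anterior_transfer = adata_spatial_anterior.copy()
adata_anterior_transfer.obs = pd.concat(
    [adata_spatial_anterior.obs, cp_anterior_df],
    axis=1
)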
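To make the shapes in `label_transfer` concrete, here is a minimal NumPy-only sketch of the same scoring steps on hypothetical toy data (3 reference cells in 2 classes, 3 spots). The helper name `label_transfer_scores` is hypothetical, and the one-hot matrix is written by hand rather than produced by `pd.get_dummies`.

```python
import numpy as np

def label_transfer_scores(sim, onehot):
    # onehot: (n_classes, n_cells); sim: (n_cells, n_spots)
    class_scores = onehot @ sim                          # (n_classes, n_spots)
    # L2-normalize each spot's vector of class scores
    class_scores = class_scores / np.linalg.norm(class_scores, axis=0)
    lo = class_scores.min(axis=1)                        # per-class minimum over spots
    span = np.ptp(class_scores, axis=1)                  # per-class range over spots
    # Broadcast per-class min/range across rows: (n_spots, n_classes)
    return (class_scores.T - lo) / span

# Hypothetical toy data: cells 0 and 1 belong to class A, cell 2 to class B
onehot = np.array([[1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
# Rows are cells, columns are spots
sim = np.array([[0.9, 0.1, 0.5],
                [0.8, 0.2, 0.4],
                [0.1, 0.9, 0.5]])

scores = label_transfer_scores(sim, onehot)
print(scores.shape)  # (3, 2)
```

Because of the final min-max step, each class column spans exactly 0 to 1 across spots, so the scores rank spots within a class rather than act as calibrated probabilities.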

Visualize the transferred scores of cortical layer neurons

[16]:
sc.set_figure_params(facecolor="white", figsize=(2, 2))
sc.pl.spatial(
    adata_anterior_transfer,
    img_key="hires",
    color=["L2/3 IT", "L4", "L5 PT", "L6 CT"],
    size=1.5,
    color_map='Blues',
)
... storing 'feature_types' as categorical
... storing 'genome' as categorical
[Figure: transferred label scores for L2/3 IT, L4, L5 PT and L6 CT on the anterior Visium section]