Context swap test

[1]:
%load_ext autoreload
%autoreload 2
from creme import creme
from creme import utils
from creme import custom_model
import pandas as pd
import matplotlib.pyplot as plt

Load Enformer and example sequences

[2]:
data_dir = "../../../data"
track_index = [5111]
model = custom_model.Enformer(track_index=track_index)
[ ]:
fasta_path = f'{data_dir}/GRCh38.primary_assembly.genome.fa'
seq_parser = utils.SequenceParser(fasta_path)

genes = ['ABCA8_chr17_68955392_-', 'NFKBIZ_chr3_101849513_+']
gene_seqs = {}

for gene in genes:
    gene_name, chrom, start, strand = gene.split('_')
    seq = seq_parser.extract_seq_centered(chrom, int(start), strand, model.seq_length)
    gene_seqs[gene_name] = seq
[ ]:
# TSS bin indices
bins = [447, 448]
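
The two bin indices above are consistent with the TSS sitting at the center of the input sequence. A minimal sketch of that arithmetic, assuming an output track of 896 bins (the same range used for `zoom` in the plots); `center_bins` is a hypothetical helper, not part of creme:

```python
# Assumed output geometry; in the notebook this would come from the model's predictions.
n_bins = 896  # number of output bins per predicted track (assumption)

def center_bins(n_bins):
    """Return the two bin indices that flank the midpoint of the track."""
    right = n_bins // 2
    return [right - 1, right]

tss_bins = center_bins(n_bins)
print(tss_bins)  # [447, 448]
```

With an even number of bins the midpoint falls on a bin boundary, so two bins are marked rather than one.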
[ ]:
abca8_wt = model.predict(gene_seqs['ABCA8'])[0,:,0]
nfkbiz_wt = model.predict(gene_seqs['NFKBIZ'])[0,:,0]
[ ]:
utils.plot_track([abca8_wt], color='green', zoom=[0, 896], marks=bins)
utils.plot_track([nfkbiz_wt], color='red', zoom=[0, 896], marks=bins)

Context swap test

In this example we take the 5 kb TSS tile of a sequence with an enhancing context and embed it in a sequence with a silencing context, and vice versa.

To run the context swap test we need:
- a loaded model
- the one-hot encoded wild-type (WT) sequence of the source, from which the TSS is taken
- the one-hot encoded WT sequence of the target, into which the TSS is embedded
- the coordinate interval that contains the TSS
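
To make the sequence manipulation concrete, here is a minimal NumPy sketch of the swap itself: the tile from the source replaces the same positions in a copy of the target. `swap_tile` is a hypothetical helper written for illustration under the assumption that both sequences are one-hot arrays of shape `(length, 4)`; it is not creme's implementation, and `creme.context_swap_test` additionally runs the model on the result.

```python
import numpy as np

def swap_tile(source, target, interval):
    """Return a copy of `target` whose [start, end) positions come from `source`.

    Hypothetical helper for illustration only.
    """
    start, end = interval
    if source.shape != target.shape:
        raise ValueError("source and target must have the same shape")
    mutant = target.copy()               # leave the target sequence untouched
    mutant[start:end] = source[start:end]  # embed the source tile in the target context
    return mutant

# Toy one-hot sequences (random, 20 positions) standing in for the real inputs.
rng = np.random.default_rng(0)
length = 20
source = np.eye(4)[rng.integers(0, 4, size=length)]
target = np.eye(4)[rng.integers(0, 4, size=length)]

mutant = swap_tile(source, target, [8, 12])
```

The mutant sequence keeps the target's flanks on both sides of the interval, so any change in the predicted TSS signal can be attributed to the surrounding context rather than to the tile.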

[ ]:
seq_halflen = model.seq_length // 2
half_window_size = 2500
N_shuffles = 10
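
As a worked example of the interval passed to the test below, the following sketch computes the tile coordinates. The value of `seq_length` here is an assumption chosen for illustration; in the notebook it comes from `model.seq_length`.

```python
seq_length = 196_608       # assumed input length; use model.seq_length in the notebook
half_window_size = 2_500   # half of the 5 kb TSS tile

seq_halflen = seq_length // 2
tile = [seq_halflen - half_window_size, seq_halflen + half_window_size]
tile_width = tile[1] - tile[0]

print(tile, tile_width)  # [95804, 100804] 5000
```

The interval is centered on the sequence midpoint, which is where `extract_seq_centered` places the TSS.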

TSS (from an enhancing sequence) embedded in a silencing context

[ ]:
pred_mut = creme.context_swap_test(model, gene_seqs['ABCA8'], gene_seqs['NFKBIZ'],
                                   [seq_halflen - half_window_size, seq_halflen + half_window_size])
[ ]:
ax=utils.plot_track([abca8_wt], color='green', zoom=[400, 500], alpha=1, label='WT')
utils.plot_track(pred_mut[:,:,0], color='purple', zoom=[400, 500], ax=ax, alpha=1, label='WT TSS + \ncontext from \nsilencing sequence')
plt.legend()

TSS (from a silencing sequence) embedded in an enhancing context

[ ]:
pred_mut = creme.context_swap_test(model, gene_seqs['NFKBIZ'], gene_seqs['ABCA8'],
                                   [seq_halflen - half_window_size, seq_halflen + half_window_size])
[ ]:
ax=utils.plot_track([nfkbiz_wt], color='red', zoom=[400, 500], alpha=1, label='WT')
utils.plot_track(pred_mut[:,:,0], color='purple', zoom=[400, 500], ax=ax, alpha=1, label='WT TSS + \ncontext from \nenhancing sequence')
plt.legend()
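
The plots can also be summarized with a single number per swap. The sketch below is one possible summary, not something creme is known to provide: it averages a track over the TSS bins and reports the relative change of the mutant against WT. Both helper functions are hypothetical, and the arrays are synthetic placeholders rather than model predictions.

```python
import numpy as np

def tss_activity(track, bins):
    """Mean predicted signal over the given bin indices of a 1D track."""
    return float(np.mean(np.asarray(track)[bins]))

def relative_change(wt_track, mut_track, bins):
    """(mutant - WT) / WT at the TSS bins; negative values mean reduced signal."""
    wt = tss_activity(wt_track, bins)
    mut = tss_activity(mut_track, bins)
    return (mut - wt) / wt

bins = [447, 448]

# Synthetic tracks for illustration only.
wt = np.ones(896)
wt[bins] = 10.0
mut = np.ones(896)
mut[bins] = 4.0

change = relative_change(wt, mut, bins)
print(round(change, 2))  # -0.6
```

In the notebook, `wt` would be `abca8_wt` or `nfkbiz_wt`, and `mut` a 1D track taken from `pred_mut`, assuming its shape matches the indexing used in the plotting cells.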
[ ]: