Higher-order interaction test
[1]:
%load_ext autoreload
%autoreload 2
from creme import creme
from creme import utils
from creme import custom_model
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import shuffle
2024-06-10 15:32:45.740168: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Load Enformer and example sequences
[2]:
data_dir = '../../../data'
track_index = [5111]
# TSS bin indeces
bins = [447, 448]
model = custom_model.Enformer(track_index=track_index, bin_index=bins)
2024-06-10 15:32:47.349063: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-10 15:32:48.819795: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 67919 MB memory: -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:85:00.0, compute capability: 8.0
[3]:
fasta_path = f'{data_dir}/GRCh38.primary_assembly.genome.fa'
seq_parser = utils.SequenceParser(fasta_path)
gene = 'CLUAP1_chr16_3501011_+'
gene_name, chrom, start, strand = gene.split('_')
wt_seq = seq_parser.extract_seq_centered(chrom, int(start), strand, model.seq_length)
wt = model.predict(wt_seq)[0,:,0].mean()
2024-06-10 15:32:53.943122: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2024-06-10 15:32:54.196676: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8401
Higher-order interaction test
In this example we will comprehensively analyze the context of the GATA2 gene using a greedy search for necessary tiles.
To run the test we need: - a loaded model - a list of tile coordinates to test - optimization - choice between [np.argmin, np.argmax] which determines if we will search for enhancers or silencers - num_shuffle - number of shuffled - number of iterations or tiles to shuffle
[4]:
perturb_window = 5000
N_shuffles = 10
tss_tile, cre_tiles = utils.set_tile_range(model.seq_length, perturb_window)
[5]:
len(cre_tiles)
[5]:
38
CLUAP1 gene context
[6]:
result_summary = creme.higher_order_interaction_test(model, wt_seq, cre_tiles.copy(),
np.argmin, num_shuffle=5, num_rounds=3)
0%| | 0/3 [00:00<?, ?it/s]
0%| | 0/38 [00:00<?, ?it/s]
3%|██ | 1/38 [00:03<01:54, 3.10s/it]
5%|████ | 2/38 [00:06<01:51, 3.10s/it]
8%|██████ | 3/38 [00:09<01:48, 3.10s/it]
11%|████████ | 4/38 [00:12<01:45, 3.10s/it]
13%|██████████ | 5/38 [00:15<01:42, 3.10s/it]
16%|████████████ | 6/38 [00:18<01:39, 3.10s/it]
18%|██████████████ | 7/38 [00:21<01:36, 3.11s/it]
21%|████████████████ | 8/38 [00:24<01:33, 3.10s/it]
24%|██████████████████ | 9/38 [00:27<01:29, 3.10s/it]
26%|███████████████████▋ | 10/38 [00:30<01:26, 3.10s/it]
29%|█████████████████████▋ | 11/38 [00:34<01:23, 3.10s/it]
32%|███████████████████████▋ | 12/38 [00:37<01:20, 3.11s/it]
34%|█████████████████████████▋ | 13/38 [00:40<01:17, 3.12s/it]
37%|███████████████████████████▋ | 14/38 [00:43<01:14, 3.12s/it]
39%|█████████████████████████████▌ | 15/38 [00:46<01:11, 3.12s/it]
42%|███████████████████████████████▌ | 16/38 [00:49<01:08, 3.13s/it]
45%|█████████████████████████████████▌ | 17/38 [00:52<01:05, 3.13s/it]
47%|███████████████████████████████████▌ | 18/38 [00:56<01:02, 3.13s/it]
50%|█████████████████████████████████████▌ | 19/38 [00:59<00:59, 3.13s/it]
53%|███████████████████████████████████████▍ | 20/38 [01:02<00:56, 3.13s/it]
55%|█████████████████████████████████████████▍ | 21/38 [01:05<00:53, 3.13s/it]
58%|███████████████████████████████████████████▍ | 22/38 [01:08<00:50, 3.13s/it]
61%|█████████████████████████████████████████████▍ | 23/38 [01:11<00:46, 3.13s/it]
63%|███████████████████████████████████████████████▎ | 24/38 [01:14<00:43, 3.13s/it]
66%|█████████████████████████████████████████████████▎ | 25/38 [01:17<00:40, 3.13s/it]
68%|███████████████████████████████████████████████████▎ | 26/38 [01:21<00:37, 3.13s/it]
71%|█████████████████████████████████████████████████████▎ | 27/38 [01:24<00:34, 3.13s/it]
74%|███████████████████████████████████████████████████████▎ | 28/38 [01:27<00:31, 3.13s/it]
76%|█████████████████████████████████████████████████████████▏ | 29/38 [01:30<00:28, 3.14s/it]
79%|███████████████████████████████████████████████████████████▏ | 30/38 [01:33<00:25, 3.14s/it]
82%|█████████████████████████████████████████████████████████████▏ | 31/38 [01:36<00:21, 3.14s/it]
84%|███████████████████████████████████████████████████████████████▏ | 32/38 [01:39<00:18, 3.14s/it]
87%|█████████████████████████████████████████████████████████████████▏ | 33/38 [01:43<00:15, 3.14s/it]
89%|███████████████████████████████████████████████████████████████████ | 34/38 [01:46<00:12, 3.14s/it]
92%|█████████████████████████████████████████████████████████████████████ | 35/38 [01:49<00:09, 3.14s/it]
95%|███████████████████████████████████████████████████████████████████████ | 36/38 [01:52<00:06, 3.14s/it]
97%|█████████████████████████████████████████████████████████████████████████ | 37/38 [01:55<00:03, 3.14s/it]
100%|███████████████████████████████████████████████████████████████████████████| 38/38 [01:58<00:00, 3.13s/it]
33%|█████████████████████████▎ | 1/3 [01:59<03:58, 119.38s/it]
0%| | 0/37 [00:00<?, ?it/s]
3%|██ | 1/37 [00:03<01:52, 3.13s/it]
5%|████ | 2/37 [00:06<01:49, 3.13s/it]
8%|██████▏ | 3/37 [00:09<01:46, 3.13s/it]
11%|████████▏ | 4/37 [00:12<01:43, 3.13s/it]
14%|██████████▎ | 5/37 [00:15<01:40, 3.13s/it]
16%|████████████▎ | 6/37 [00:18<01:36, 3.13s/it]
19%|██████████████▍ | 7/37 [00:21<01:33, 3.13s/it]
22%|████████████████▍ | 8/37 [00:25<01:30, 3.13s/it]
24%|██████████████████▍ | 9/37 [00:28<01:27, 3.13s/it]
27%|████████████████████▎ | 10/37 [00:31<01:24, 3.13s/it]
30%|██████████████████████▎ | 11/37 [00:34<01:21, 3.13s/it]
32%|████████████████████████▎ | 12/37 [00:37<01:18, 3.13s/it]
35%|██████████████████████████▎ | 13/37 [00:40<01:15, 3.13s/it]
38%|████████████████████████████▍ | 14/37 [00:43<01:11, 3.13s/it]
41%|██████████████████████████████▍ | 15/37 [00:46<01:08, 3.13s/it]
43%|████████████████████████████████▍ | 16/37 [00:50<01:05, 3.13s/it]
46%|██████████████████████████████████▍ | 17/37 [00:53<01:02, 3.13s/it]
49%|████████████████████████████████████▍ | 18/37 [00:56<00:59, 3.13s/it]
51%|██████████████████████████████████████▌ | 19/37 [00:59<00:56, 3.13s/it]
54%|████████████████████████████████████████▌ | 20/37 [01:02<00:53, 3.13s/it]
57%|██████████████████████████████████████████▌ | 21/37 [01:05<00:50, 3.13s/it]
59%|████████████████████████████████████████████▌ | 22/37 [01:08<00:46, 3.13s/it]
62%|██████████████████████████████████████████████▌ | 23/37 [01:12<00:43, 3.14s/it]
65%|████████████████████████████████████████████████▋ | 24/37 [01:15<00:40, 3.14s/it]
68%|██████████████████████████████████████████████████▋ | 25/37 [01:18<00:37, 3.14s/it]
70%|████████████████████████████████████████████████████▋ | 26/37 [01:21<00:34, 3.14s/it]
73%|██████████████████████████████████████████████████████▋ | 27/37 [01:24<00:31, 3.14s/it]
76%|████████████████████████████████████████████████████████▊ | 28/37 [01:27<00:28, 3.14s/it]
78%|██████████████████████████████████████████████████████████▊ | 29/37 [01:30<00:25, 3.14s/it]
81%|████████████████████████████████████████████████████████████▊ | 30/37 [01:34<00:22, 3.14s/it]
84%|██████████████████████████████████████████████████████████████▊ | 31/37 [01:37<00:18, 3.14s/it]
86%|████████████████████████████████████████████████████████████████▊ | 32/37 [01:40<00:15, 3.14s/it]
89%|██████████████████████████████████████████████████████████████████▉ | 33/37 [01:43<00:12, 3.15s/it]
92%|████████████████████████████████████████████████████████████████████▉ | 34/37 [01:46<00:09, 3.15s/it]
95%|██████████████████████████████████████████████████████████████████████▉ | 35/37 [01:49<00:06, 3.15s/it]
97%|████████████████████████████████████████████████████████████████████████▉ | 36/37 [01:52<00:03, 3.15s/it]
100%|███████████████████████████████████████████████████████████████████████████| 37/37 [01:56<00:00, 3.14s/it]
67%|██████████████████████████████████████████████████▋ | 2/3 [03:56<01:57, 117.78s/it]
0%| | 0/36 [00:00<?, ?it/s]
3%|██ | 1/36 [00:03<01:49, 3.14s/it]
6%|████▏ | 2/36 [00:06<01:47, 3.15s/it]
8%|██████▎ | 3/36 [00:09<01:43, 3.15s/it]
11%|████████▍ | 4/36 [00:12<01:40, 3.15s/it]
14%|██████████▌ | 5/36 [00:15<01:37, 3.15s/it]
17%|████████████▋ | 6/36 [00:18<01:34, 3.15s/it]
19%|██████████████▊ | 7/36 [00:22<01:31, 3.15s/it]
22%|████████████████▉ | 8/36 [00:25<01:28, 3.15s/it]
25%|███████████████████ | 9/36 [00:28<01:24, 3.14s/it]
28%|████████████████████▊ | 10/36 [00:31<01:22, 3.15s/it]
31%|██████████████████████▉ | 11/36 [00:34<01:18, 3.15s/it]
33%|█████████████████████████ | 12/36 [00:37<01:15, 3.15s/it]
36%|███████████████████████████ | 13/36 [00:40<01:12, 3.15s/it]
39%|█████████████████████████████▏ | 14/36 [00:44<01:09, 3.15s/it]
42%|███████████████████████████████▎ | 15/36 [00:47<01:06, 3.15s/it]
44%|█████████████████████████████████▎ | 16/36 [00:50<01:03, 3.15s/it]
47%|███████████████████████████████████▍ | 17/36 [00:53<00:59, 3.15s/it]
50%|█████████████████████████████████████▌ | 18/36 [00:56<00:56, 3.15s/it]
53%|███████████████████████████████████████▌ | 19/36 [00:59<00:53, 3.15s/it]
56%|█████████████████████████████████████████▋ | 20/36 [01:02<00:50, 3.15s/it]
58%|███████████████████████████████████████████▊ | 21/36 [01:06<00:47, 3.15s/it]
61%|█████████████████████████████████████████████▊ | 22/36 [01:09<00:44, 3.15s/it]
64%|███████████████████████████████████████████████▉ | 23/36 [01:12<00:40, 3.15s/it]
67%|██████████████████████████████████████████████████ | 24/36 [01:15<00:37, 3.15s/it]
69%|████████████████████████████████████████████████████ | 25/36 [01:18<00:34, 3.15s/it]
72%|██████████████████████████████████████████████████████▏ | 26/36 [01:21<00:31, 3.15s/it]
75%|████████████████████████████████████████████████████████▎ | 27/36 [01:25<00:28, 3.15s/it]
78%|██████████████████████████████████████████████████████████▎ | 28/36 [01:28<00:25, 3.15s/it]
81%|████████████████████████████████████████████████████████████▍ | 29/36 [01:31<00:22, 3.14s/it]
83%|██████████████████████████████████████████████████████████████▌ | 30/36 [01:34<00:18, 3.15s/it]
86%|████████████████████████████████████████████████████████████████▌ | 31/36 [01:37<00:15, 3.15s/it]
89%|██████████████████████████████████████████████████████████████████▋ | 32/36 [01:40<00:12, 3.15s/it]
92%|████████████████████████████████████████████████████████████████████▊ | 33/36 [01:43<00:09, 3.15s/it]
94%|██████████████████████████████████████████████████████████████████████▊ | 34/36 [01:47<00:06, 3.15s/it]
97%|████████████████████████████████████████████████████████████████████████▉ | 35/36 [01:50<00:03, 3.15s/it]
100%|███████████████████████████████████████████████████████████████████████████| 36/36 [01:53<00:00, 3.15s/it]
100%|████████████████████████████████████████████████████████████████████████████| 3/3 [05:50<00:00, 116.67s/it]
Result breakdown
During the first iteration, we start with the WT sequence. The initial prediction is the same as WT.
[7]:
result_summary[0]['initial_pred'] == wt.mean()
[7]:
True
In the first iteration all the tiles are shuffled producing 38 prediction values. The lowest TSS activity value will be selected in this iteration and saved as selected_mean_pred - in this case the point circled in black in panel 1. After this the tile is kept shuffled and iteration 2 starts with the remaining 37 tiles. In each iteration 1 tile is elminated and the plots show the selected tile in each case - the shuffling of which yields the lowest TSS activity.
[8]:
fig, axes = plt.subplots(1, len(result_summary.keys()), figsize=[15, 4], sharey=True)
shuffled_tiles = []
for i in result_summary.keys():
ax = axes[i]
print(f'Number of Tiles shuffled in iteration {i+1}: {result_summary[i]["preds"].shape[0]}')
mean_preds = result_summary[i]['preds'].mean(axis=-1).copy()
for s in shuffled_tiles:
mean_preds = np.insert(mean_preds, s, np.nan)
min_i = np.nanargmin(mean_preds)
shuffled_tiles.append(min_i)
sns.scatterplot(x=range(38), y=mean_preds/wt, color='purple', ax=ax)
ax.set_xlabel('Tile number')
axes[0].set_ylabel('Normalized prediction')
assert np.min(result_summary[i]['preds'].mean(axis=-1)) == result_summary[i]['selected_mean_pred']
sns.scatterplot(x=[min_i], y=[mean_preds[min_i]/wt], s=80, facecolors='none', edgecolors='k', linewidth=2, ax=ax)
ax.annotate(f'selected tile #{min_i}', [min_i+1.5, mean_preds[min_i]/wt]);
plt.tight_layout()
Number of Tiles shuffled in iteration 1: 38
Number of Tiles shuffled in iteration 2: 37
Number of Tiles shuffled in iteration 3: 36
Now let’s re-create the sequences in each iteration by following the tile order in ‘selected tile’ to shuffle putative CRE tiles, then obtain predictions and plot those versus the WT.
[9]:
model.bin_index = None # switch to saving all track not just TSS bins
wt_track = model.predict(wt_seq)[0,:,0]
[10]:
shuffled_tile_seq = wt_seq.copy()
per_iteration_preds = []
for i in result_summary.keys():
preds = []
tile_start, tile_end = result_summary[i]['selected_tile']
#shuffle this
for j in range(10): # number of shuffles
background_seq = shuffle.dinuc_shuffle(wt_seq)
shuffled_tile_seq[tile_start: tile_end] = background_seq[tile_start: tile_end]
preds.append(model.predict(shuffled_tile_seq)[0,:,0])
per_iteration_preds.append(preds)
[11]:
ax=utils.plot_track([wt_track], label='WT', color='k', alpha=1)
colors = ['purple', 'cyan', 'orange']
for i, preds in enumerate(per_iteration_preds):
utils.plot_track([np.array(preds).mean(axis=0)], zoom=[420, 480], label=f'iteration {i+1}', ax=ax, color=colors[i])
ax.set_ylim(0,100)
plt.legend();
[ ]: